Skip to content

Commit

Permalink
Merge pull request #2069 from ResearchHub/fix-discussion-count-hacks
Browse files Browse the repository at this point in the history
Fix method for calculating discussion count
  • Loading branch information
koutst authored Jan 13, 2025
2 parents 650c9a4 + 62715fd commit 2b1a7de
Show file tree
Hide file tree
Showing 11 changed files with 383 additions and 137 deletions.
34 changes: 1 addition & 33 deletions src/paper/related_models/paper_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,39 +673,7 @@ def get_promoted_score(paper):
return False

def get_discussion_count(self):
from discussion.models import Thread

sources = [Thread.RESEARCHHUB, Thread.INLINE_ABSTRACT, Thread.INLINE_PAPER_BODY]

thread_count = self.threads.aggregate(
discussion_count=Count(
1,
filter=Q(
is_removed=False, created_by__isnull=False, source__in=sources
),
)
)["discussion_count"]
comment_count = self.threads.aggregate(
discussion_count=Count(
"comments",
filter=Q(
comments__is_removed=False,
comments__created_by__isnull=False,
source__in=sources,
),
)
)["discussion_count"]
reply_count = self.threads.aggregate(
discussion_count=Count(
"comments__replies",
filter=Q(
comments__replies__is_removed=False,
comments__replies__created_by__isnull=False,
source__in=sources,
),
)
)["discussion_count"]
return thread_count + comment_count + reply_count
return self.rh_threads.get_discussion_count()

def extract_pdf_preview(self, use_celery=True):
if TESTING:
Expand Down
4 changes: 1 addition & 3 deletions src/paper/serializers/paper_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,7 @@ def get_discussions(self, paper):
return serializer.data

def get_discussion_aggregates(self, paper):
aggregates = paper.rh_threads.get_discussion_aggregates()
aggregates["discussion_count"] = paper.discussion_count
return aggregates
return paper.rh_threads.get_discussion_aggregates(paper)

def get_hubs(self, paper):
context = self.context
Expand Down
14 changes: 14 additions & 0 deletions src/researchhub_comment/related_models/rh_comment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,20 @@ def _update_related_discussion_count(self, amount):
related_document.discussion_count += amount
related_document.save(update_fields=["discussion_count"])

def refresh_related_discussion_count(self):
    """
    Recompute and persist the related document's `discussion_count`.

    Nothing happens when the thread's content object is a `CitationEntry`.
    Otherwise the document is resolved through the thread's unified
    document; if it exposes a `discussion_count` attribute, that attribute
    is overwritten with the result of `get_discussion_count()` and only
    that field is saved.
    """
    # Imported at call time, as in the original (presumably to avoid a
    # circular import between the comment and citation apps -- TODO confirm).
    from citation.models import CitationEntry

    thread = self.thread
    if isinstance(thread.content_object, CitationEntry):
        return

    related_document = thread.unified_document.get_document()

    # Guard the write explicitly: a document without `discussion_count`
    # has nothing to refresh, and asking Django to save that field on
    # such a model would be an error.
    if not hasattr(related_document, "discussion_count"):
        return

    related_document.discussion_count = related_document.get_discussion_count()
    related_document.save(update_fields=["discussion_count"])

def increment_discussion_count(self):
self._update_related_discussion_count(1)

Expand Down
94 changes: 57 additions & 37 deletions src/researchhub_comment/related_models/rh_comment_thread_model.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,59 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db import connection, models
from django.db.models import CharField, Count, JSONField, Q
from django.db.models.functions import Cast

from researchhub_access_group.models import Permission
from researchhub_comment.constants.rh_comment_thread_types import (
GENERIC_COMMENT,
INNER_CONTENT_COMMENT,
PEER_REVIEW,
RH_COMMENT_THREAD_TYPES,
SUMMARY,
)
from utils.models import AbstractGenericRelationModel

"""
NOTE: RhCommentThreadModel's generic relation convention is to
- setup relations through AbstractGenericRelationModel
- add an edge named `rh_threads` for inverse reference on top of a target [content] model
- (see Hypothesis Model for reference)

This allows queries such as [ContentModel].rh_threads[...]
where [ContentModels] may be found in method "get_valid_target_content_model"
"""
class RhCommentThreadQuerySet(models.QuerySet):
def get_discussion_count(self):
"""
Uses a recursive CTE to count all comments (including replies)
ignoring removed comments, for all threads in this QuerySet.
"""
thread_ids = list(self.values_list("id", flat=True))
if not thread_ids:
return 0

query = """
WITH RECURSIVE comment_tree AS (
SELECT c.id, c.parent_id, 1 AS comment_count
FROM researchhub_comment_rhcommentmodel c
JOIN researchhub_comment_rhcommentthreadmodel t ON c.thread_id = t.id
WHERE c.parent_id IS NULL
AND c.is_removed = FALSE
AND t.id IN (SELECT UNNEST(%s))
class RhCommentThreadManager(models.Manager):
def get_discussion_aggregates(self):
return self.exclude(rh_comments__bounties__isnull=False).aggregate(
discussion_count=Count(
"rh_comments",
filter=(
Q(thread_type=INNER_CONTENT_COMMENT)
| Q(thread_type=GENERIC_COMMENT)
)
& Q(
rh_comments__is_removed=False,
rh_comments__bounties__isnull=True,
rh_comments__parent__bounties__isnull=True,
),
),
UNION ALL
SELECT c.id, c.parent_id, 1
FROM researchhub_comment_rhcommentmodel c
JOIN comment_tree ct ON c.parent_id = ct.id
WHERE c.is_removed = FALSE
)
SELECT COALESCE(SUM(comment_count), 0)::int AS total_count
FROM comment_tree;
"""

with connection.cursor() as cursor:
cursor.execute(query, [thread_ids])
result = cursor.fetchone()
return result[0] if result else 0

def get_discussion_aggregates(self, item):
"""
Example aggregator, adapted from your code.
Note: self.exclude(...) etc. uses the QuerySet instead of Manager.
"""
aggregator = self.exclude(rh_comments__bounties__isnull=False).aggregate(
review_count=Count(
"rh_comments",
filter=Q(
Expand All @@ -58,34 +74,38 @@ def get_discussion_aggregates(self):
),
)

aggregator["discussion_count"] = item.discussion_count
return aggregator

class RhCommentThreadModel(AbstractGenericRelationModel):
"""--- MODEL FIELDS ---"""

class RhCommentThreadManager(models.Manager):
    """Manager that serves `RhCommentThreadQuerySet` and forwards its helpers."""

    def get_queryset(self):
        """Build a `RhCommentThreadQuerySet` for this manager's model and database."""
        queryset = RhCommentThreadQuerySet(self.model, using=self._db)
        return queryset

    def get_discussion_count(self):
        """Forward to `get_discussion_count()` on this manager's queryset."""
        threads = self.get_queryset()
        return threads.get_discussion_count()

    def get_discussion_aggregates(self, item):
        """Forward to `get_discussion_aggregates(item)` on this manager's queryset."""
        threads = self.get_queryset()
        return threads.get_discussion_aggregates(item)


class RhCommentThreadModel(AbstractGenericRelationModel):
thread_type = CharField(
max_length=144,
choices=RH_COMMENT_THREAD_TYPES,
default=GENERIC_COMMENT,
)
thread_reference = CharField(
blank=True,
help_text="""A thread may need a special referencing tool. Use this field for such a case""",
help_text="A thread may need a special referencing tool. Use this field for such a case",
max_length=144,
null=True,
)
anchor = JSONField(blank=True, null=True)
permissions = GenericRelation(
Permission,
related_name="rh_thread",
)
permissions = GenericRelation(Permission, related_name="rh_thread")

"""--- OBJECT MANAGER ---"""
objects = RhCommentThreadManager()

""" --- PROPERTIES --- """

@property
def unified_document(self):
return self.content_object.unified_document

"""--- METHODS ---"""
145 changes: 143 additions & 2 deletions src/researchhub_comment/tests/test_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _create_post_comment(
self.client.force_authenticate(created_by)

res = self._create_comment(
"post",
"researchhubpost",
post_id,
created_by,
{
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_do_not_notify_unqualified_users_about_bounty(self):

self.assertEqual(notification.exists(), False)

def test_censor_comment_updates_discussion_count(self):
def test_censor_paper_comments_updates_discussion_count(self):
# Create a parent comment with multiple children
parent_comment = self._create_paper_comment(self.paper.id, self.user_1)
child1 = self._create_paper_comment(
Expand Down Expand Up @@ -346,3 +346,144 @@ def test_censor_nested_comments(self):
# Should reduce count by 2 (child2 and child3)
paper_res = self.client.get(f"/api/paper/{self.paper.id}/")
self.assertEqual(paper_res.data["discussion_count"], 2)

def test_censor_post_comments_updates_discussion_count(self):
    """Censoring a post's parent comment takes the discussion count from 3 to 0."""
    body_text = "Test content needs to be 50 characters long, minimum."
    post_payload = {
        "title": "Test Post needs to be 20 characters long",
        "content_json": {"ops": [{"insert": body_text}]},
        "document_type": "DISCUSSION",
        "full_src": body_text,
        "renderable_text": body_text,
    }

    # The post is created by user_1.
    self.client.force_authenticate(self.user_1)
    created_post = self.client.post("/api/researchhubpost/", post_payload)
    post_id = created_post.data["id"]
    post_url = f"/api/researchhubpost/{post_id}/"

    # One top-level comment with two direct replies from different users.
    parent_id = self._create_post_comment(post_id, self.user_1).data["id"]
    for replier in (self.user_2, self.user_3):
        self._create_post_comment(post_id, replier, parent_id=parent_id)

    # All three comments are counted before any moderation.
    count_before = self.client.get(post_url).data["discussion_count"]
    self.assertEqual(count_before, 3)

    # A moderator censors the parent comment.
    self.client.force_authenticate(self.moderator)
    self.client.delete(f"{post_url}comments/{parent_id}/censor/")

    # Parent and both replies no longer count.
    count_after = self.client.get(post_url).data["discussion_count"]
    self.assertEqual(count_after, 0)

def test_censor_post_child_with_deleted_parent_preserves_count(self):
    """
    Censor a reply whose parent comment was already deleted.

    The expected final discussion count is exactly one less than the
    count observed right after the parent was deleted.
    """
    body_text = "Test content needs to be 50 characters long, minimum."
    post_payload = {
        "title": "Test Post needs to be 20 characters long",
        "content_json": {"ops": [{"insert": body_text}]},
        "document_type": "DISCUSSION",
        "full_src": body_text,
        "renderable_text": body_text,
    }

    # The post is created by user_1.
    self.client.force_authenticate(self.user_1)
    created_post = self.client.post("/api/researchhubpost/", post_payload)
    post_id = created_post.data["id"]
    post_url = f"/api/researchhubpost/{post_id}/"

    # A parent comment by user_1 and a single reply by user_2.
    parent_id = self._create_post_comment(post_id, self.user_1).data["id"]
    child_id = self._create_post_comment(
        post_id, self.user_2, parent_id=parent_id
    ).data["id"]

    # The parent's author deletes the parent first.
    self.client.force_authenticate(self.user_1)
    self.client.delete(f"{post_url}comments/{parent_id}/")

    # Baseline taken after the parent deletion.
    count_after_parent_delete = self.client.get(post_url).data["discussion_count"]

    # A moderator then censors the remaining reply.
    self.client.force_authenticate(self.moderator)
    self.client.delete(f"{post_url}comments/{child_id}/censor/")

    final_count = self.client.get(post_url).data["discussion_count"]
    self.assertEqual(final_count, count_after_parent_delete - 1)

def test_censor_nested_post_comments(self):
    """Censoring a reply that has its own reply lowers the count from 4 to 2."""
    body_text = "Test content needs to be 50 characters long, minimum."
    post_payload = {
        "title": "Test Post needs to be 20 characters long",
        "content_json": {"ops": [{"insert": body_text}]},
        "document_type": "DISCUSSION",
        "full_src": body_text,
        "renderable_text": body_text,
    }

    # The post is created by user_1.
    self.client.force_authenticate(self.user_1)
    created_post = self.client.post("/api/researchhubpost/", post_payload)
    post_id = created_post.data["id"]
    post_url = f"/api/researchhubpost/{post_id}/"

    # Tree: parent -> child1 -> grandchild1, and parent -> child2.
    parent_id = self._create_post_comment(post_id, self.user_1).data["id"]
    child1_id = self._create_post_comment(
        post_id, self.user_2, parent_id=parent_id
    ).data["id"]
    self._create_post_comment(post_id, self.user_3, parent_id=child1_id)
    self._create_post_comment(post_id, self.user_2, parent_id=parent_id)

    # All four comments are counted up front.
    count_before = self.client.get(post_url).data["discussion_count"]
    self.assertEqual(count_before, 4)

    # A moderator censors child1, the comment that has a reply of its own.
    self.client.force_authenticate(self.moderator)
    self.client.delete(f"{post_url}comments/{child1_id}/censor/")

    # The count is expected to drop by two (child1 and grandchild1).
    count_after = self.client.get(post_url).data["discussion_count"]
    self.assertEqual(count_after, 2)
Loading

0 comments on commit 2b1a7de

Please sign in to comment.