Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join style queries for dict, set, list, and embedded fields #177

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Contributions by
* Sabin Iacob (https://github.com/m0n5t3r)
* kryton (https://github.com/kryton)
* Brandon Pedersen (https://github.com/bpedman)
* Brian Gontowski (https://github.com/Molanda)

(For an up-to-date list of contributors, see
https://github.com/django-mongodb-engine/mongodb-engine/contributors.)
3 changes: 3 additions & 0 deletions django_mongodb_engine/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def insert(self, docs, return_id=False):
doc.clear()
else:
raise DatabaseError("Can't save entity with _id set to None")
for d in doc.keys():
if '.' in d:
del doc[d]

collection = self.get_collection()
options = self.connection.operation_flags.get('save', {})
Expand Down
50 changes: 50 additions & 0 deletions django_mongodb_engine/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import sys
import re

from django.db import models, connections
from django.db.models.query import QuerySet
from django.db.models.sql.query import Query as SQLQuery
from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
from django_mongodb_engine.compiler import OPERATORS_MAP, NEGATED_OPERATORS_MAP
from djangotoolbox.fields import AbstractIterableField


ON_PYPY = hasattr(sys, 'pypy_version_info')
ALL_OPERATORS = dict(list(OPERATORS_MAP.items() + NEGATED_OPERATORS_MAP.items())).keys()
MONGO_DOT_FIELDS = ('DictField', 'ListField', 'SetField', 'EmbeddedModelField')


def _compiler_for_queryset(qs, which='SQLCompiler'):
Expand Down Expand Up @@ -84,6 +91,49 @@ def __repr__(self):


class MongoDBQuerySet(QuerySet):
def _filter_or_exclude(self, negate, *args, **kwargs):
if args or kwargs:
assert self.query.can_filter(), \
"Cannot filter a query once a slice has been taken."

clone = self._clone()

all_field_names = self.model._meta.get_all_field_names()
base_field_names = []

for name in all_field_names:
field = self.model._meta.get_field_by_name(name)[0]
if '.' not in name and field.get_internal_type() in MONGO_DOT_FIELDS:
base_field_names.append(name)

for key, val in kwargs.items():
if LOOKUP_SEP in key and key.split(LOOKUP_SEP)[0] in base_field_names:
del kwargs[key]
for op in ALL_OPERATORS:
if key.endswith(op):
key = re.sub(LOOKUP_SEP + op + '$', '#' + op, key)
break
key = key.replace(LOOKUP_SEP, '.').replace('#', LOOKUP_SEP)
kwargs[key] = val
name = key.split(LOOKUP_SEP)[0]
if '.' in name and name not in all_field_names:
parts = name.split('.')
column = self.model._meta.get_field_by_name(parts[0])[0].db_column
if column:
parts[0] = column
field = AbstractIterableField(
db_column = '.'.join(parts),
blank=True,
null=True,
editable=False,
)
field.contribute_to_class(self.model, name)

if negate:
clone.query.add_q(~Q(*args, **kwargs))
else:
clone.query.add_q(Q(*args, **kwargs))
return clone

def map_reduce(self, *args, **kwargs):
"""
Expand Down
Empty file added tests/dotquery/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions tests/dotquery/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from django.db import models
from djangotoolbox.fields import ListField, DictField, EmbeddedModelField
from django_mongodb_engine.contrib import MongoDBManager


class DotQueryEmbeddedModel(models.Model):
f_int = models.IntegerField()


class DotQueryTestModel(models.Model):
objects = MongoDBManager()

f_id = models.IntegerField()
f_dict = DictField(db_column='test_dict')
f_list = ListField()
f_embedded = EmbeddedModelField(DotQueryEmbeddedModel)
f_embedded_list = ListField(EmbeddedModelField(DotQueryEmbeddedModel))
81 changes: 81 additions & 0 deletions tests/dotquery/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import with_statement
from models import *
from utils import *


class DotQueryTests(TestCase):
"""Tests for querying on foo.bar using join syntax."""

def setUp(self):
DotQueryTestModel.objects.create(
f_id=51,
f_dict={'numbers': [1, 2, 3], 'letters': 'abc'},
f_list=[{'color': 'red'}, {'color': 'blue'}],
f_embedded=DotQueryEmbeddedModel(f_int=10),
f_embedded_list=[
DotQueryEmbeddedModel(f_int=100),
DotQueryEmbeddedModel(f_int=101),
],
)
DotQueryTestModel.objects.create(
f_id=52,
f_dict={'numbers': [2, 3], 'letters': 'bc'},
f_list=[{'color': 'red'}, {'color': 'green'}],
f_embedded=DotQueryEmbeddedModel(f_int=11),
f_embedded_list=[
DotQueryEmbeddedModel(f_int=110),
DotQueryEmbeddedModel(f_int=111),
],
)
DotQueryTestModel.objects.create(
f_id=53,
f_dict={'numbers': [3, 4], 'letters': 'cd'},
f_list=[{'color': 'yellow'}, {'color': 'orange'}],
f_embedded=DotQueryEmbeddedModel(f_int=12),
f_embedded_list=[
DotQueryEmbeddedModel(f_int=120),
DotQueryEmbeddedModel(f_int=121),
],
)

def tearDown(self):
DotQueryTestModel.objects.all().delete()

def test_dict_queries(self):
q = DotQueryTestModel.objects.filter(f_dict__numbers=2)
self.assertEqual(q.count(), 2)
q = DotQueryTestModel.objects.filter(f_dict__letters__contains='b')
self.assertEqual(q.count(), 2)
q = DotQueryTestModel.objects.exclude(f_dict__letters__contains='b')
self.assertEqual(q.count(), 1)
self.assertEqual(q[0].f_id, 53)

def test_list_queries(self):
q = DotQueryTestModel.objects.filter(f_list__color='red')
q = q.exclude(f_list__color='green')
q = q.exclude(f_list__color='purple')
self.assertEqual(q.count(), 1)
self.assertEqual(q[0].f_id, 51)

def test_embedded_queries(self):
q = DotQueryTestModel.objects.exclude(f_embedded__f_int__in=[10, 12])
self.assertEqual(q.count(), 1)
self.assertEqual(q[0].f_id, 52)

def test_embedded_list_queries(self):
q = DotQueryTestModel.objects.get(f_embedded_list__f_int=120)
self.assertEqual(q.f_id, 53)

def test_save_after_query(self):
q = DotQueryTestModel.objects.get(f_dict__letters='cd')
self.assertEqual(q.f_id, 53)
q.f_id = 1053
q.clean()
q.save()
q = DotQueryTestModel.objects.get(f_dict__letters='cd')
self.assertEqual(q.f_id, 1053)
q.f_id = 53
q.clean()
q.save()
q = DotQueryTestModel.objects.get(f_dict__letters='cd')
self.assertEqual(q.f_id, 53)
35 changes: 35 additions & 0 deletions tests/dotquery/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from django.conf import settings
from django.db import connections
from django.db.models import Model
from django.test import TestCase
from django.utils.unittest import skip


class TestCase(TestCase):

def setUp(self):
super(TestCase, self).setUp()
if getattr(settings, 'TEST_DEBUG', False):
settings.DEBUG = True

def assertEqualLists(self, a, b):
self.assertEqual(list(a), list(b))


def skip_all_except(*tests):

class meta(type):

def __new__(cls, name, bases, dict):
for attr in dict.keys():
if attr.startswith('test_') and attr not in tests:
del dict[attr]
return type.__new__(cls, name, bases, dict)

return meta


def get_collection(model_or_name):
if isinstance(model_or_name, type) and issubclass(model_or_name, Model):
model_or_name = model_or_name._meta.db_table
return connections['default'].get_collection(model_or_name)
1 change: 1 addition & 0 deletions tests/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
'aggregations',
'contrib',
'storage',
'dotquery',
]