Skip to content

Commit

Permalink
994 Add support for more functions (upper, lower etc) (#996)
Browse files Browse the repository at this point in the history
* support upper and lower

* remove prototype code

* rename `operator` to `outer_function`

* remove check for `call_chain`

* allow multiple outer_functions to be passed in

* update docstring

* use querystrings instead

* fix some linter issues

* fix linter errors

* add album table

* fix `load_json` on joined tables

* move logic for `get_select_string` from `Function` to `QueryString`

* use columns in querystring for joins

* add `not_in` to `querystring`

* add `get_where_string`

* set `_alias` in querystring __init__

* refactor `table_alias`

* move functions into a new folder

* re-export `Upper` and `Lower`

* add ltrim and rtrim functions

* add more functions

* improve error message

* add default value for `getattr` when fetching querystring columns

* add initial tests

* add a test for alias

* deprecate `Unquoted` - `QueryString` can be used directly

* simplify alias handling for `Function`

* don't get alias from child `QueryString`

* add `Reverse` function

* add `TestNested`

* fix sqlite tests

* improve tracking of columns within querystrings

* increase test timeouts

* add missing imports

* improve functions nested within `QueryString`

* refactor aggregate functions to use new format

* make sure where clauses work with functions

* fix linter errors

* update docs
  • Loading branch information
dantownsend authored May 29, 2024
1 parent da92fbe commit a85b404
Show file tree
Hide file tree
Showing 25 changed files with 660 additions and 367 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:
jobs:
linters:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 60
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
Expand All @@ -35,7 +35,7 @@ jobs:

integration:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 60
strategy:
matrix:
# These tests are slow, so we only run on the latest Python
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:

postgres:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 60
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
Expand Down Expand Up @@ -138,7 +138,7 @@ jobs:

cockroach:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 60
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
Expand Down Expand Up @@ -172,7 +172,7 @@ jobs:

sqlite:
runs-on: ubuntu-latest
timeout-minutes: 30
timeout-minutes: 60
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
Expand Down
4 changes: 2 additions & 2 deletions docs/src/piccolo/query_clauses/group_by.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ In the following query, we get a count of the number of bands per manager:

.. code-block:: python
>>> from piccolo.query.methods.select import Count
>>> from piccolo.query.functions.aggregate import Count
>>> await Band.select(
... Band.manager.name.as_alias('manager_name'),
... Count(alias='band_count')
... ).group_by(
... Band.manager
... Band.manager.name
... )
[
Expand Down
2 changes: 1 addition & 1 deletion docs/src/piccolo/query_types/count.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ It's equivalent to this ``select`` query:

.. code-block:: python
from piccolo.query.methods.select import Count
from piccolo.query.functions.aggregate import Count
>>> response = await Band.select(Count())
>>> response[0]['count']
Expand Down
37 changes: 31 additions & 6 deletions docs/src/piccolo/query_types/select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,31 @@ convenient.
-------------------------------------------------------------------------------

String functions
----------------

Piccolo has lots of string functions built-in. See
``piccolo/query/functions/string.py``. Here's an example using ``Upper``, to
convert values to uppercase:

.. code-block:: python
from piccolo.query.functions.string import Upper
>> await Band.select(Upper(Band.name, alias='name'))
[{'name': 'PYTHONISTAS'}, ...]
You can also use these within where clauses:

.. code-block:: python
from piccolo.query.functions.string import Upper
>> await Band.select(Band.name).where(Upper(Band.manager.name) == 'GUIDO')
[{'name': 'Pythonistas'}]
-------------------------------------------------------------------------------

.. _AggregateFunctions:

Aggregate functions
Expand All @@ -182,7 +207,7 @@ Returns the number of matching rows.

.. code-block:: python
from piccolo.query.methods.select import Count
from piccolo.query.functions.aggregate import Count
>> await Band.select(Count()).where(Band.popularity > 100)
[{'count': 3}]
Expand All @@ -196,7 +221,7 @@ Returns the average for a given column:

.. code-block:: python
>>> from piccolo.query import Avg
>>> from piccolo.query.functions.aggregate import Avg
>>> response = await Band.select(Avg(Band.popularity)).first()
>>> response["avg"]
750.0
Expand All @@ -208,7 +233,7 @@ Returns the sum for a given column:

.. code-block:: python
>>> from piccolo.query import Sum
>>> from piccolo.query.functions.aggregate import Sum
>>> response = await Band.select(Sum(Band.popularity)).first()
>>> response["sum"]
1500
Expand All @@ -220,7 +245,7 @@ Returns the maximum for a given column:

.. code-block:: python
>>> from piccolo.query import Max
>>> from piccolo.query.functions.aggregate import Max
>>> response = await Band.select(Max(Band.popularity)).first()
>>> response["max"]
1000
Expand All @@ -232,7 +257,7 @@ Returns the minimum for a given column:

.. code-block:: python
>>> from piccolo.query import Min
>>> from piccolo.query.functions.aggregate import Min
>>> response = await Band.select(Min(Band.popularity)).first()
>>> response["min"]
500
Expand All @@ -244,7 +269,7 @@ You also can have multiple different aggregate functions in one query:

.. code-block:: python
>>> from piccolo.query import Avg, Sum
>>> from piccolo.query.functions.aggregate import Avg, Sum
>>> response = await Band.select(
... Avg(Band.popularity),
... Sum(Band.popularity)
Expand Down
71 changes: 31 additions & 40 deletions piccolo/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import inspect
import typing as t
import uuid
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field, fields
from enum import Enum

Expand All @@ -32,6 +31,7 @@
NotLike,
)
from piccolo.columns.reference import LazyTableReference
from piccolo.querystring import QueryString, Selectable
from piccolo.utils.warnings import colored_warning

if t.TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -205,7 +205,6 @@ def table(self) -> t.Type[Table]:

# Used by Foreign Keys:
call_chain: t.List["ForeignKey"] = field(default_factory=list)
table_alias: t.Optional[str] = None

###########################################################################

Expand Down Expand Up @@ -260,7 +259,7 @@ def _get_path(self, include_quotes: bool = False):
column_name = self.db_column_name

if self.call_chain:
table_alias = self.call_chain[-1]._meta.table_alias
table_alias = self.call_chain[-1].table_alias
if include_quotes:
return f'"{table_alias}"."{column_name}"'
else:
Expand All @@ -272,7 +271,9 @@ def _get_path(self, include_quotes: bool = False):
return f"{self.table._meta.tablename}.{column_name}"

def get_full_name(
self, with_alias: bool = True, include_quotes: bool = True
self,
with_alias: bool = True,
include_quotes: bool = True,
) -> str:
"""
Returns the full column name, taking into account joins.
Expand Down Expand Up @@ -302,11 +303,10 @@ def get_full_name(
>>> column._meta.get_full_name(include_quotes=False)
'my_table_name.my_column_name'
"""
full_name = self._get_path(include_quotes=include_quotes)

if with_alias and self.call_chain:
if with_alias:
alias = self.get_default_alias()
if include_quotes:
full_name += f' AS "{alias}"'
Expand Down Expand Up @@ -346,32 +346,6 @@ def __deepcopy__(self, memo) -> ColumnMeta:
return self.copy()


class Selectable(metaclass=ABCMeta):
"""
Anything which inherits from this can be used in a select query.
"""

_alias: t.Optional[str]

@abstractmethod
def get_select_string(
self, engine_type: str, with_alias: bool = True
) -> str:
"""
In a query, what to output after the select statement - could be a
column name, a sub query, a function etc. For a column it will be the
column name.
"""
raise NotImplementedError()

def as_alias(self, alias: str) -> Selectable:
"""
Allows column names to be changed in the result of a select.
"""
self._alias = alias
return self


class Column(Selectable):
"""
All other columns inherit from ``Column``. Don't use it directly.
Expand Down Expand Up @@ -822,25 +796,32 @@ def get_default_value(self) -> t.Any:

def get_select_string(
self, engine_type: str, with_alias: bool = True
) -> str:
) -> QueryString:
"""
How to refer to this column in a SQL query, taking account of any joins
and aliases.
"""

if with_alias:
if self._alias:
original_name = self._meta.get_full_name(
with_alias=False,
)
return f'{original_name} AS "{self._alias}"'
return QueryString(f'{original_name} AS "{self._alias}"')
else:
return self._meta.get_full_name(
with_alias=True,
return QueryString(
self._meta.get_full_name(
with_alias=True,
)
)

return self._meta.get_full_name(with_alias=False)
return QueryString(
self._meta.get_full_name(
with_alias=False,
)
)

def get_where_string(self, engine_type: str) -> str:
def get_where_string(self, engine_type: str) -> QueryString:
return self.get_select_string(
engine_type=engine_type, with_alias=False
)
Expand Down Expand Up @@ -902,6 +883,13 @@ def get_sql_value(self, value: t.Any) -> t.Any:
def column_type(self):
return self.__class__.__name__.upper()

@property
def table_alias(self) -> str:
return "$".join(
f"{_key._meta.table._meta.tablename}${_key._meta.name}"
for _key in [*self._meta.call_chain, self]
)

@property
def ddl(self) -> str:
"""
Expand Down Expand Up @@ -945,8 +933,8 @@ def ddl(self) -> str:

return query

def copy(self) -> Column:
column: Column = copy.copy(self)
def copy(self: Self) -> Self:
column = copy.copy(self)
column._meta = self._meta.copy()
return column

Expand All @@ -971,3 +959,6 @@ def __repr__(self):
f"{table_class_name}.{self._meta.name} - "
f"{self.__class__.__name__}"
)


Self = t.TypeVar("Self", bound=Column)
19 changes: 11 additions & 8 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Band(Table):
from piccolo.columns.operators.comparison import ArrayAll, ArrayAny
from piccolo.columns.operators.string import Concat
from piccolo.columns.reference import LazyTableReference
from piccolo.querystring import QueryString, Unquoted
from piccolo.querystring import QueryString
from piccolo.utils.encoding import dump_json
from piccolo.utils.warnings import colored_warning

Expand Down Expand Up @@ -752,8 +752,8 @@ def __set__(self, obj, value: t.Union[int, None]):
###############################################################################


DEFAULT = Unquoted("DEFAULT")
NULL = Unquoted("null")
DEFAULT = QueryString("DEFAULT")
NULL = QueryString("null")


class Serial(Column):
Expand All @@ -778,7 +778,7 @@ def default(self):
if engine_type == "postgres":
return DEFAULT
elif engine_type == "cockroach":
return Unquoted("unique_rowid()")
return QueryString("unique_rowid()")
elif engine_type == "sqlite":
return NULL
raise Exception("Unrecognized engine type")
Expand Down Expand Up @@ -2194,6 +2194,7 @@ def __getattribute__(self, name: str) -> t.Union[Column, t.Any]:
column_meta: ColumnMeta = object.__getattribute__(self, "_meta")

new_column._meta.call_chain = column_meta.call_chain.copy()

new_column._meta.call_chain.append(self)
return new_column
else:
Expand Down Expand Up @@ -2311,7 +2312,7 @@ def arrow(self, key: str) -> JSONB:

def get_select_string(
self, engine_type: str, with_alias: bool = True
) -> str:
) -> QueryString:
select_string = self._meta.get_full_name(with_alias=False)

if self.json_operator is not None:
Expand All @@ -2321,7 +2322,7 @@ def get_select_string(
alias = self._alias or self._meta.get_default_alias()
select_string += f' AS "{alias}"'

return select_string
return QueryString(select_string)

def eq(self, value) -> Where:
"""
Expand Down Expand Up @@ -2616,7 +2617,9 @@ def __getitem__(self, value: int) -> Array:
else:
raise ValueError("Only integers can be used for indexing.")

def get_select_string(self, engine_type: str, with_alias=True) -> str:
def get_select_string(
self, engine_type: str, with_alias=True
) -> QueryString:
select_string = self._meta.get_full_name(with_alias=False)

if isinstance(self.index, int):
Expand All @@ -2626,7 +2629,7 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str:
alias = self._alias or self._meta.get_default_alias()
select_string += f' AS "{alias}"'

return select_string
return QueryString(select_string)

def any(self, value: t.Any) -> Where:
"""
Expand Down
Loading

0 comments on commit a85b404

Please sign in to comment.