commit 3e1fa59b3d
5723 changed files with 757971 additions and 0 deletions

@@ -0,0 +1,193 @@
from .comparison import Cast, Coalesce, Collate, Greatest, Least, NullIf
from .datetime import (
    Extract,
    ExtractDay,
    ExtractHour,
    ExtractIsoWeekDay,
    ExtractIsoYear,
    ExtractMinute,
    ExtractMonth,
    ExtractQuarter,
    ExtractSecond,
    ExtractWeek,
    ExtractWeekDay,
    ExtractYear,
    Now,
    Trunc,
    TruncDate,
    TruncDay,
    TruncHour,
    TruncMinute,
    TruncMonth,
    TruncQuarter,
    TruncSecond,
    TruncTime,
    TruncWeek,
    TruncYear,
)
from .json import JSONArray, JSONObject
from .math import (
    Abs,
    ACos,
    ASin,
    ATan,
    ATan2,
    Ceil,
    Cos,
    Cot,
    Degrees,
    Exp,
    Floor,
    Ln,
    Log,
    Mod,
    Pi,
    Power,
    Radians,
    Random,
    Round,
    Sign,
    Sin,
    Sqrt,
    Tan,
)
from .text import (
    MD5,
    SHA1,
    SHA224,
    SHA256,
    SHA384,
    SHA512,
    Chr,
    Concat,
    ConcatPair,
    Left,
    Length,
    Lower,
    LPad,
    LTrim,
    Ord,
    Repeat,
    Replace,
    Reverse,
    Right,
    RPad,
    RTrim,
    StrIndex,
    Substr,
    Trim,
    Upper,
)
from .window import (
    CumeDist,
    DenseRank,
    FirstValue,
    Lag,
    LastValue,
    Lead,
    NthValue,
    Ntile,
    PercentRank,
    Rank,
    RowNumber,
)

__all__ = [
    # comparison and conversion
    "Cast",
    "Coalesce",
    "Collate",
    "Greatest",
    "Least",
    "NullIf",
    # datetime
    "Extract",
    "ExtractDay",
    "ExtractHour",
    "ExtractMinute",
    "ExtractMonth",
    "ExtractQuarter",
    "ExtractSecond",
    "ExtractWeek",
    "ExtractIsoWeekDay",
    "ExtractWeekDay",
    "ExtractIsoYear",
    "ExtractYear",
    "Now",
    "Trunc",
    "TruncDate",
    "TruncDay",
    "TruncHour",
    "TruncMinute",
    "TruncMonth",
    "TruncQuarter",
    "TruncSecond",
    "TruncTime",
    "TruncWeek",
    "TruncYear",
    # json
    "JSONArray",
    "JSONObject",
    # math
    "Abs",
    "ACos",
    "ASin",
    "ATan",
    "ATan2",
    "Ceil",
    "Cos",
    "Cot",
    "Degrees",
    "Exp",
    "Floor",
    "Ln",
    "Log",
    "Mod",
    "Pi",
    "Power",
    "Radians",
    "Random",
    "Round",
    "Sign",
    "Sin",
    "Sqrt",
    "Tan",
    # text
    "MD5",
    "SHA1",
    "SHA224",
    "SHA256",
    "SHA384",
    "SHA512",
    "Chr",
    "Concat",
    "ConcatPair",
    "Left",
    "Length",
    "Lower",
    "LPad",
    "LTrim",
    "Ord",
    "Repeat",
    "Replace",
    "Reverse",
    "Right",
    "RPad",
    "RTrim",
    "StrIndex",
    "Substr",
    "Trim",
    "Upper",
    # window
    "CumeDist",
    "DenseRank",
    "FirstValue",
    "Lag",
    "LastValue",
    "Lead",
    "NthValue",
    "Ntile",
    "PercentRank",
    "Rank",
    "RowNumber",
]
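Everything above is re-exported so callers can import from a single place, django.db.models.functions. A minimal usage sketch, assuming a hypothetical Author model with name, nickname, and updated_at fields (the model and field names are illustrative, not part of this commit):

from django.db.models.functions import Coalesce, Lower, Now

# Annotate each row with a lowercased name and a NULL-safe display name,
# then stamp the current database time on an update.
authors = Author.objects.annotate(
    name_lower=Lower("name"),
    display_name=Coalesce("nickname", "name"),
)
Author.objects.filter(pk=1).update(updated_at=Now())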
@@ -0,0 +1,172 @@
"""Database functions that do comparisons or type conversions."""

from django.db.models.expressions import Func, Value
from django.utils.regex_helper import _lazy_re_compile


class Cast(Func):
    """Coerce an expression to a new field type."""

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        extra_context["db_type"] = self.output_field.cast_db_type(connection)
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        db_type = self.output_field.db_type(connection)
        if db_type in {"datetime", "time"}:
            # Use strftime as datetime/time don't keep fractional seconds.
            template = "strftime(%%s, %(expressions)s)"
            sql, params = super().as_sql(
                compiler, connection, template=template, **extra_context
            )
            format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
            params.insert(0, format_string)
            return sql, params
        elif db_type == "date":
            template = "date(%(expressions)s)"
            return super().as_sql(
                compiler, connection, template=template, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        template = None
        output_type = self.output_field.get_internal_type()
        # MySQL doesn't support explicit cast to float.
        if output_type == "FloatField":
            template = "(%(expressions)s + 0.0)"
        # MariaDB doesn't support explicit cast to JSON.
        elif output_type == "JSONField" and connection.mysql_is_mariadb:
            template = "JSON_EXTRACT(%(expressions)s, '$')"
        return self.as_sql(compiler, connection, template=template, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        # CAST would be valid too, but the :: shortcut syntax is more readable.
        # 'expressions' is wrapped in parentheses in case it's a complex
        # expression.
        return self.as_sql(
            compiler,
            connection,
            template="(%(expressions)s)::%(db_type)s",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        if self.output_field.get_internal_type() == "JSONField":
            # Oracle doesn't support explicit cast to JSON.
            template = "JSON_QUERY(%(expressions)s, '$')"
            return super().as_sql(
                compiler, connection, template=template, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)


class Coalesce(Func):
    """Return, from left to right, the first non-null expression."""

    function = "COALESCE"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Coalesce must take at least two expressions")
        super().__init__(*expressions, **extra)

    @property
    def empty_result_set_value(self):
        for expression in self.get_source_expressions():
            result = expression.empty_result_set_value
            if result is NotImplemented or result is not None:
                return result
        return None

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
        # so convert all fields to NCLOB when that type is expected.
        if self.output_field.get_internal_type() == "TextField":
            clone = self.copy()
            clone.set_source_expressions(
                [
                    Func(expression, function="TO_NCLOB")
                    for expression in self.get_source_expressions()
                ]
            )
            return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
        return self.as_sql(compiler, connection, **extra_context)


class Collate(Func):
    function = "COLLATE"
    template = "%(expressions)s %(function)s %(collation)s"
    allowed_default = False
    # Inspired from
    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    collation_re = _lazy_re_compile(r"^[\w-]+$")

    def __init__(self, expression, collation):
        if not (collation and self.collation_re.match(collation)):
            raise ValueError("Invalid collation name: %r." % collation)
        self.collation = collation
        super().__init__(expression)

    def as_sql(self, compiler, connection, **extra_context):
        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
        return super().as_sql(compiler, connection, **extra_context)


class Greatest(Func):
    """
    Return the maximum expression.

    If any expression is null the return value is database-specific:
    On PostgreSQL, the maximum not-null expression is returned.
    On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Use the MAX function on SQLite."""
        return super().as_sqlite(compiler, connection, function="MAX", **extra_context)


class Least(Func):
    """
    Return the minimum expression.

    If any expression is null the return value is database-specific:
    On PostgreSQL, return the minimum not-null expression.
    On MySQL, Oracle, and SQLite, if any expression is null, return null.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Use the MIN function on SQLite."""
        return super().as_sqlite(compiler, connection, function="MIN", **extra_context)


class NullIf(Func):
    function = "NULLIF"
    arity = 2

    def as_oracle(self, compiler, connection, **extra_context):
        expression1 = self.get_source_expressions()[0]
        if isinstance(expression1, Value) and expression1.value is None:
            raise ValueError("Oracle does not allow Value(None) for expression1.")
        return super().as_sql(compiler, connection, **extra_context)
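A short annotate() sketch for the comparison and conversion functions above, assuming a hypothetical Product model with a price DecimalField and a nullable discount_price DecimalField (the names are illustrative, not part of this commit):

from django.db.models import FloatField
from django.db.models.functions import Cast, Coalesce, Greatest, NullIf

products = Product.objects.annotate(
    # First non-null value, left to right; at least two expressions are required.
    effective_price=Coalesce("discount_price", "price"),
    # Explicit conversion to a new output type.
    price_as_float=Cast("price", output_field=FloatField()),
    # NULL handling is database-specific (see the Greatest docstring above).
    higher_price=Greatest("price", "discount_price"),
    # NULL when the two arguments compare equal.
    undiscounted=NullIf("discount_price", "price"),
)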
@@ -0,0 +1,439 @@
from datetime import datetime

from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
    DateField,
    DateTimeField,
    DurationField,
    Field,
    IntegerField,
    TimeField,
)
from django.db.models.lookups import (
    Transform,
    YearExact,
    YearGt,
    YearGte,
    YearLt,
    YearLte,
)
from django.utils import timezone


class TimezoneMixin:
    tzinfo = None

    def get_tzname(self):
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        tzname = None
        if settings.USE_TZ:
            if self.tzinfo is None:
                tzname = timezone.get_current_timezone_name()
            else:
                tzname = timezone._get_timezone_name(self.tzinfo)
        return tzname


class Extract(TimezoneMixin, Transform):
    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        else:
            # resolve_expression has already validated the output_field so this
            # assert should never be hit.
            assert False, "Tried to Extract from an invalid type."
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            return copy
        if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a mistake.
        if type(field) is DateField and copy.lookup_name in (
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'."
                % (copy.lookup_name, field.name)
            )
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                "Cannot extract component '%s' from DurationField '%s'."
                % (copy.lookup_name, field.name)
            )
        return copy


class ExtractYear(Extract):
    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    lookup_name = "month"


class ExtractDay(Extract):
    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
    week.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Return Monday=1 through Sunday=7, based on ISO-8601."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    lookup_name = "quarter"


class ExtractHour(Extract):
    lookup_name = "hour"


class ExtractMinute(Extract):
    lookup_name = "minute"


class ExtractSecond(Extract):
    lookup_name = "second"


DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)

TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)

DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)

ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)

ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)


class Now(Func):
    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        return self.as_sql(
            compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler,
            connection,
            template="STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler, connection, template="LOCALTIMESTAMP", **extra_context
        )


class TruncBase(TimezoneMixin, Transform):
    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, (DateField, TimeField)):
            raise TypeError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '%s' to %s."
                % (
                    field.name,
                    (
                        output_field.__class__.__name__
                        if has_explicit_output_field
                        else "DateTimeField"
                    ),
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '%s' to %s."
                % (
                    field.name,
                    (
                        output_field.__class__.__name__
                        if has_explicit_output_field
                        else "DateTimeField"
                    ),
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        if isinstance(self.output_field, DateTimeField):
            if not settings.USE_TZ:
                pass
            elif value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            if value is None:
                pass
            elif isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value


class Trunc(TruncBase):
    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        self.kind = kind
        super().__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)


class TruncYear(TruncBase):
    kind = "year"


class TruncQuarter(TruncBase):
    kind = "quarter"


class TruncMonth(TruncBase):
    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"


class TruncDay(TruncBase):
    kind = "day"


class TruncDate(TruncBase):
    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        # Cast to date rather than truncate to date.
        sql, params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)


class TruncTime(TruncBase):
    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        # Cast to time rather than truncate to time.
        sql, params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)


class TruncHour(TruncBase):
    kind = "hour"


class TruncMinute(TruncBase):
    kind = "minute"


class TruncSecond(TruncBase):
    kind = "second"


DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)
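The Extract*/Trunc* transforms registered above double as lookups. A sketch assuming a hypothetical Order model with a created DateTimeField (illustrative names only):

from django.db.models import Count
from django.db.models.functions import TruncMonth

# Registered lookup syntax: created__year resolves to ExtractYear.
recent = Order.objects.filter(created__year__gte=2024)

# Group rows per calendar month; when USE_TZ is enabled, truncation happens in
# the current time zone (see TimezoneMixin.get_tzname()).
per_month = (
    Order.objects.annotate(month=TruncMonth("created"))
    .values("month")
    .annotate(n=Count("id"))
    .order_by("month")
)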
@@ -0,0 +1,124 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import TextField
from django.db.models.fields.json import JSONField
from django.db.models.functions import Cast


class JSONArray(Func):
    function = "JSON_ARRAY"
    output_field = JSONField()

    def as_sql(self, compiler, connection, **extra_context):
        if not connection.features.supports_json_field:
            raise NotSupportedError(
                "JSONFields are not supported on this database backend."
            )
        return super().as_sql(compiler, connection, **extra_context)

    def as_native(self, compiler, connection, *, returning, **extra_context):
        # PostgreSQL 16+ and Oracle remove SQL NULL values from the array by
        # default. Adds the NULL ON NULL clause to keep NULL values in the
        # array, mapping them to JSON null values, which matches the behavior
        # of SQLite.
        null_on_null = "NULL ON NULL" if len(self.get_source_expressions()) > 0 else ""

        return self.as_sql(
            compiler,
            connection,
            template=(
                f"%(function)s(%(expressions)s {null_on_null} RETURNING {returning})"
            ),
            **extra_context,
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Casting source expressions is only required using JSONB_BUILD_ARRAY
        # or when using JSON_ARRAY on PostgreSQL 16+ with server-side bindings.
        # This is done in all cases for consistency.
        casted_obj = self.copy()
        casted_obj.set_source_expressions(
            [
                (
                    # Conditional Cast to avoid unnecessary wrapping.
                    expression
                    if isinstance(expression, Cast)
                    else Cast(expression, expression.output_field)
                )
                for expression in casted_obj.get_source_expressions()
            ]
        )

        if connection.features.is_postgresql_16:
            return casted_obj.as_native(
                compiler, connection, returning="JSONB", **extra_context
            )

        return casted_obj.as_sql(
            compiler,
            connection,
            function="JSONB_BUILD_ARRAY",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        return self.as_native(compiler, connection, returning="CLOB", **extra_context)


class JSONObject(Func):
    function = "JSON_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields):
        expressions = []
        for key, value in fields.items():
            expressions.extend((Value(key), value))
        super().__init__(*expressions)

    def as_sql(self, compiler, connection, **extra_context):
        if not connection.features.has_json_object_function:
            raise NotSupportedError(
                "JSONObject() is not supported on this database backend."
            )
        return super().as_sql(compiler, connection, **extra_context)

    def join(self, args):
        pairs = zip(args[::2], args[1::2], strict=True)
        # Wrap 'key' in parentheses in case of postgres cast :: syntax.
        return ", ".join([f"({key}) VALUE {value}" for key, value in pairs])

    def as_native(self, compiler, connection, *, returning, **extra_context):
        return self.as_sql(
            compiler,
            connection,
            arg_joiner=self,
            template=f"%(function)s(%(expressions)s RETURNING {returning})",
            **extra_context,
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Casting keys to text is only required when using JSONB_BUILD_OBJECT
        # or when using JSON_OBJECT on PostgreSQL 16+ with server-side bindings.
        # This is done in all cases for consistency.
        copy = self.copy()
        copy.set_source_expressions(
            [
                Cast(expression, TextField()) if index % 2 == 0 else expression
                for index, expression in enumerate(copy.get_source_expressions())
            ]
        )

        if connection.features.is_postgresql_16:
            return copy.as_native(
                compiler, connection, returning="JSONB", **extra_context
            )

        return super(JSONObject, copy).as_sql(
            compiler,
            connection,
            function="JSONB_BUILD_OBJECT",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        return self.as_native(compiler, connection, returning="CLOB", **extra_context)
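JSONObject and JSONArray build JSON values directly in the SELECT list. A sketch assuming a hypothetical User model with first_name and last_name CharFields (illustrative names only):

from django.db.models import F
from django.db.models.functions import JSONObject, Lower

users = User.objects.annotate(
    # Keys are the keyword argument names; values may be any expression.
    as_json=JSONObject(name=Lower("first_name"), surname=F("last_name")),
)
# On PostgreSQL < 16 this compiles to JSONB_BUILD_OBJECT(...); on 16+ it uses
# JSON_OBJECT(... RETURNING JSONB). Backends without a JSON_OBJECT function
# raise NotSupportedError, as shown in as_sql() above.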
@@ -0,0 +1,214 @@
import math

from django.db.models.expressions import Func, Value
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
    FixDecimalInputMixin,
    NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform


class Abs(Transform):
    function = "ABS"
    lookup_name = "abs"


class ACos(NumericOutputFieldMixin, Transform):
    function = "ACOS"
    lookup_name = "acos"


class ASin(NumericOutputFieldMixin, Transform):
    function = "ASIN"
    lookup_name = "asin"


class ATan(NumericOutputFieldMixin, Transform):
    function = "ATAN"
    lookup_name = "atan"


class ATan2(NumericOutputFieldMixin, Func):
    function = "ATAN2"
    arity = 2

    def as_sqlite(self, compiler, connection, **extra_context):
        if not getattr(
            connection.ops, "spatialite", False
        ) or connection.ops.spatial_version >= (5, 0, 0):
            return self.as_sql(compiler, connection)
        # This function is usually ATan2(y, x), returning the inverse tangent
        # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
        # Cast integers to float to avoid inconsistent/buggy behavior if the
        # arguments are mixed between integer and float or decimal.
        # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
        clone = self.copy()
        clone.set_source_expressions(
            [
                (
                    Cast(expression, FloatField())
                    if isinstance(expression.output_field, IntegerField)
                    else expression
                )
                for expression in self.get_source_expressions()[::-1]
            ]
        )
        return clone.as_sql(compiler, connection, **extra_context)


class Ceil(Transform):
    function = "CEILING"
    lookup_name = "ceil"

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="CEIL", **extra_context)


class Cos(NumericOutputFieldMixin, Transform):
    function = "COS"
    lookup_name = "cos"


class Cot(NumericOutputFieldMixin, Transform):
    function = "COT"
    lookup_name = "cot"

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
        )


class Degrees(NumericOutputFieldMixin, Transform):
    function = "DEGREES"
    lookup_name = "degrees"

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            template="((%%(expressions)s) * 180 / %s)" % math.pi,
            **extra_context,
        )


class Exp(NumericOutputFieldMixin, Transform):
    function = "EXP"
    lookup_name = "exp"


class Floor(Transform):
    function = "FLOOR"
    lookup_name = "floor"


class Ln(NumericOutputFieldMixin, Transform):
    function = "LN"
    lookup_name = "ln"


class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    function = "LOG"
    arity = 2

    def as_sqlite(self, compiler, connection, **extra_context):
        if not getattr(connection.ops, "spatialite", False):
            return self.as_sql(compiler, connection)
        # This function is usually Log(b, x) returning the logarithm of x to
        # the base b, but on SpatiaLite it's Log(x, b).
        clone = self.copy()
        clone.set_source_expressions(self.get_source_expressions()[::-1])
        return clone.as_sql(compiler, connection, **extra_context)


class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    function = "MOD"
    arity = 2


class Pi(NumericOutputFieldMixin, Func):
    function = "PI"
    arity = 0

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, template=str(math.pi), **extra_context
        )


class Power(NumericOutputFieldMixin, Func):
    function = "POWER"
    arity = 2


class Radians(NumericOutputFieldMixin, Transform):
    function = "RADIANS"
    lookup_name = "radians"

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            template="((%%(expressions)s) * %s / 180)" % math.pi,
            **extra_context,
        )


class Random(NumericOutputFieldMixin, Func):
    function = "RANDOM"
    arity = 0

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def get_group_by_cols(self):
        return []


class Round(FixDecimalInputMixin, Transform):
    function = "ROUND"
    lookup_name = "round"
    arity = None  # Override Transform's arity=1 to enable passing precision.

    def __init__(self, expression, precision=0, **extra):
        super().__init__(expression, precision, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        precision = self.get_source_expressions()[1]
        if isinstance(precision, Value) and precision.value < 0:
            raise ValueError("SQLite does not support negative precision.")
        return super().as_sqlite(compiler, connection, **extra_context)

    def _resolve_output_field(self):
        source = self.get_source_expressions()[0]
        return source.output_field


class Sign(Transform):
    function = "SIGN"
    lookup_name = "sign"


class Sin(NumericOutputFieldMixin, Transform):
    function = "SIN"
    lookup_name = "sin"


class Sqrt(NumericOutputFieldMixin, Transform):
    function = "SQRT"
    lookup_name = "sqrt"


class Tan(NumericOutputFieldMixin, Transform):
    function = "TAN"
    lookup_name = "tan"
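Most of the math functions are Transforms, so they work both in annotate() and, once registered, as lookups. A sketch assuming a hypothetical Measurement model with a FloatField named value (illustrative names only):

from django.db.models import FloatField
from django.db.models.functions import Abs, Round

qs = Measurement.objects.annotate(
    magnitude=Abs("value"),
    rounded=Round("value", precision=2),  # arity=None lets Round accept a precision
)

# A Transform's lookup_name is only usable after explicit registration:
FloatField.register_lookup(Abs)
large = Measurement.objects.filter(value__abs__gt=10)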
@@ -0,0 +1,62 @@
import sys

from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.functions import Cast


class FixDecimalInputMixin:
    def as_postgresql(self, compiler, connection, **extra_context):
        # Cast FloatField to DecimalField as PostgreSQL doesn't support the
        # following function signatures:
        # - LOG(double, double)
        # - MOD(double, double)
        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
        clone = self.copy()
        clone.set_source_expressions(
            [
                (
                    Cast(expression, output_field)
                    if isinstance(expression.output_field, FloatField)
                    else expression
                )
                for expression in self.get_source_expressions()
            ]
        )
        return clone.as_sql(compiler, connection, **extra_context)


class FixDurationInputMixin:
    def as_mysql(self, compiler, connection, **extra_context):
        sql, params = super().as_sql(compiler, connection, **extra_context)
        if self.output_field.get_internal_type() == "DurationField":
            sql = "CAST(%s AS SIGNED)" % sql
        return sql, params

    def as_oracle(self, compiler, connection, **extra_context):
        if (
            self.output_field.get_internal_type() == "DurationField"
            and not connection.features.supports_aggregation_over_interval_types
        ):
            expression = self.get_source_expressions()[0]
            options = self._get_repr_options()
            from django.db.backends.oracle.functions import (
                IntervalToSeconds,
                SecondsToInterval,
            )

            return compiler.compile(
                SecondsToInterval(
                    self.__class__(IntervalToSeconds(expression), **options)
                )
            )
        return super().as_sql(compiler, connection, **extra_context)


class NumericOutputFieldMixin:
    def _resolve_output_field(self):
        source_fields = self.get_source_fields()
        if any(isinstance(s, DecimalField) for s in source_fields):
            return DecimalField()
        if any(isinstance(s, IntegerField) for s in source_fields):
            return FloatField()
        return super()._resolve_output_field() if source_fields else FloatField()
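These mixins only affect SQL generation and output-field resolution; they are not used directly in querysets. As a rough illustration of NumericOutputFieldMixin: any DecimalField source yields DecimalField(), otherwise any IntegerField source yields FloatField(), and with no sources at all the result is FloatField().

from django.db.models.functions import Power

# Hypothetical integer column: Power mixes in NumericOutputFieldMixin, so
# expr.output_field resolves to FloatField() rather than IntegerField().
expr = Power("int_column", 2)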
@@ -0,0 +1,376 @@
from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import CharField, IntegerField, TextField
from django.db.models.functions import Cast, Coalesce
from django.db.models.lookups import Transform


class MySQLSHA2Mixin:
    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            template="SHA2(%%(expressions)s, %s)" % self.function[3:],
            **extra_context,
        )


class OracleHashMixin:
    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            template=(
                "LOWER(RAWTOHEX(STANDARD_HASH(UTL_I18N.STRING_TO_RAW("
                "%(expressions)s, 'AL32UTF8'), '%(function)s')))"
            ),
            **extra_context,
        )


class PostgreSQLSHAMixin:
    def as_postgresql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
            function=self.function.lower(),
            **extra_context,
        )


class Chr(Transform):
    function = "CHR"
    lookup_name = "chr"
    output_field = CharField()

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template="%(function)s(%(expressions)s USING utf16)",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            template="%(function)s(%(expressions)s USING NCHAR_CS)",
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="CHAR", **extra_context)


class ConcatPair(Func):
    """
    Concatenate two arguments together. This is used by `Concat` because not
    all backend databases support more than two arguments.
    """

    function = "CONCAT"

    def pipes_concat_sql(self, compiler, connection, **extra_context):
        coalesced = self.coalesce()
        return super(ConcatPair, coalesced).as_sql(
            compiler,
            connection,
            template="(%(expressions)s)",
            arg_joiner=" || ",
            **extra_context,
        )

    as_sqlite = pipes_concat_sql

    def as_postgresql(self, compiler, connection, **extra_context):
        c = self.copy()
        c.set_source_expressions(
            [
                (
                    expression
                    if isinstance(expression.output_field, (CharField, TextField))
                    else Cast(expression, TextField())
                )
                for expression in c.get_source_expressions()
            ]
        )
        return c.pipes_concat_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

    def coalesce(self):
        # null on either side results in null for expression, wrap with coalesce
        c = self.copy()
        c.set_source_expressions(
            [
                Coalesce(expression, Value(""))
                for expression in c.get_source_expressions()
            ]
        )
        return c


class Concat(Func):
    """
    Concatenate text fields together. Backends that result in an entire
    null expression when any arguments are null will wrap each argument in
    coalesce functions to ensure a non-null result.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        paired = self._paired(expressions, output_field=extra.get("output_field"))
        super().__init__(paired, **extra)

    def _paired(self, expressions, output_field):
        # wrap pairs of expressions in successive concat functions
        # exp = [a, b, c, d]
        # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
        if len(expressions) == 2:
            return ConcatPair(*expressions, output_field=output_field)
        return ConcatPair(
            expressions[0],
            self._paired(expressions[1:], output_field=output_field),
            output_field=output_field,
        )


class Left(Func):
    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string
        """
        if not hasattr(length, "resolve_expression"):
            if length < 1:
                raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        return Substr(self.source_expressions[0], Value(1), self.source_expressions[1])

    def as_oracle(self, compiler, connection, **extra_context):
        return self.get_substr().as_oracle(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        return self.get_substr().as_sqlite(compiler, connection, **extra_context)


class Length(Transform):
    """Return the number of characters in the expression."""

    function = "LENGTH"
    lookup_name = "length"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="CHAR_LENGTH", **extra_context
        )


class Lower(Transform):
    function = "LOWER"
    lookup_name = "lower"


class LPad(Func):
    function = "LPAD"
    output_field = CharField()

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        if (
            not hasattr(length, "resolve_expression")
            and length is not None
            and length < 0
        ):
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)


class LTrim(Transform):
    function = "LTRIM"
    lookup_name = "ltrim"


class MD5(OracleHashMixin, Transform):
    function = "MD5"
    lookup_name = "md5"


class Ord(Transform):
    function = "ASCII"
    lookup_name = "ord"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="ORD", **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="UNICODE", **extra_context)


class Repeat(Func):
    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression, number, **extra):
        if (
            not hasattr(number, "resolve_expression")
            and number is not None
            and number < 0
        ):
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        expression, number = self.source_expressions
        length = None if number is None else Length(expression) * number
        rpad = RPad(expression, length, expression)
        return rpad.as_sql(compiler, connection, **extra_context)


class Replace(Func):
    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        super().__init__(expression, text, replacement, **extra)


class Reverse(Transform):
    function = "REVERSE"
    lookup_name = "reverse"

    def as_oracle(self, compiler, connection, **extra_context):
        # REVERSE in Oracle is undocumented and doesn't support multi-byte
        # strings. Use a special subquery instead.
        suffix = connection.features.bare_select_suffix
        sql, params = super().as_sql(
            compiler,
            connection,
            template=(
                "(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
                f"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s{suffix} "
                "CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
                "GROUP BY %(expressions)s)"
            ),
            **extra_context,
        )
        return sql, params * 3


class Right(Left):
    function = "RIGHT"

    def get_substr(self):
        return Substr(
            self.source_expressions[0],
            self.source_expressions[1] * Value(-1),
            self.source_expressions[1],
        )


class RPad(LPad):
    function = "RPAD"


class RTrim(Transform):
    function = "RTRIM"
    lookup_name = "rtrim"


class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA1"
    lookup_name = "sha1"


class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    function = "SHA224"
    lookup_name = "sha224"

    def as_oracle(self, compiler, connection, **extra_context):
        raise NotSupportedError("SHA224 is not supported on Oracle.")


class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA256"
    lookup_name = "sha256"


class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA384"
    lookup_name = "sha384"


class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
    function = "SHA512"
    lookup_name = "sha512"


class StrIndex(Func):
    """
    Return a positive integer corresponding to the 1-indexed position of the
    first occurrence of a substring inside another string, or 0 if the
    substring is not found.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="STRPOS", **extra_context)


class Substr(Func):
    function = "SUBSTRING"
    output_field = CharField()

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        if not hasattr(pos, "resolve_expression"):
            if pos < 1:
                raise ValueError("'pos' must be greater than 0")
        expressions = [expression, pos]
        if length is not None:
            expressions.append(length)
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)

    def as_oracle(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)


class Trim(Transform):
    function = "TRIM"
    lookup_name = "trim"


class Upper(Transform):
    function = "UPPER"
    lookup_name = "upper"
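A sketch for the text functions, assuming a hypothetical Person model with first_name and last_name CharFields (illustrative names only):

from django.db.models import Value
from django.db.models.functions import Concat, Length, Lower, Substr

people = Person.objects.annotate(
    # Concat nests ConcatPair(a, ConcatPair(b, c)) and coalesces NULLs on
    # backends where a single NULL argument would null the whole result.
    full_name=Concat("first_name", Value(" "), "last_name"),
    initial=Substr("first_name", 1, 1),
    name_length=Length("last_name"),
    lowered=Lower("first_name"),
)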
@@ -0,0 +1,120 @@
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField

__all__ = [
    "CumeDist",
    "DenseRank",
    "FirstValue",
    "Lag",
    "LastValue",
    "Lead",
    "NthValue",
    "Ntile",
    "PercentRank",
    "Rank",
    "RowNumber",
]


class CumeDist(Func):
    function = "CUME_DIST"
    output_field = FloatField()
    window_compatible = True


class DenseRank(Func):
    function = "DENSE_RANK"
    output_field = IntegerField()
    window_compatible = True


class FirstValue(Func):
    arity = 1
    function = "FIRST_VALUE"
    window_compatible = True


class LagLeadFunction(Func):
    window_compatible = True

    def __init__(self, expression, offset=1, default=None, **extra):
        if expression is None:
            raise ValueError(
                "%s requires a non-null source expression." % self.__class__.__name__
            )
        if offset is None or offset <= 0:
            raise ValueError(
                "%s requires a positive integer for the offset."
                % self.__class__.__name__
            )
        args = (expression, offset)
        if default is not None:
            args += (default,)
        super().__init__(*args, **extra)

    def _resolve_output_field(self):
        sources = self.get_source_expressions()
        return sources[0].output_field


class Lag(LagLeadFunction):
    function = "LAG"


class LastValue(Func):
    arity = 1
    function = "LAST_VALUE"
    window_compatible = True


class Lead(LagLeadFunction):
    function = "LEAD"


class NthValue(Func):
    function = "NTH_VALUE"
    window_compatible = True

    def __init__(self, expression, nth=1, **extra):
        if expression is None:
            raise ValueError(
                "%s requires a non-null source expression." % self.__class__.__name__
            )
        if nth is None or nth <= 0:
            raise ValueError(
                "%s requires a positive integer as for nth." % self.__class__.__name__
            )
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self):
        sources = self.get_source_expressions()
        return sources[0].output_field


class Ntile(Func):
    function = "NTILE"
    output_field = IntegerField()
    window_compatible = True

    def __init__(self, num_buckets=1, **extra):
        if num_buckets <= 0:
            raise ValueError("num_buckets must be greater than 0.")
        super().__init__(num_buckets, **extra)


class PercentRank(Func):
    function = "PERCENT_RANK"
    output_field = FloatField()
    window_compatible = True


class Rank(Func):
    function = "RANK"
    output_field = IntegerField()
    window_compatible = True


class RowNumber(Func):
    function = "ROW_NUMBER"
    output_field = IntegerField()
    window_compatible = True
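The window functions above are only meaningful inside a Window expression, which supplies partitioning and ordering. A sketch assuming a hypothetical Employee model with department and salary fields (illustrative names only):

from django.db.models import F, Window
from django.db.models.functions import Rank

employees = Employee.objects.annotate(
    salary_rank=Window(
        expression=Rank(),  # only window_compatible functions are allowed here
        partition_by=[F("department")],
        order_by=F("salary").desc(),
    ),
)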