start-pack
This commit is contained in:
commit
3e1fa59b3d
5723 changed files with 757971 additions and 0 deletions
File diff suppressed because it is too large
Load diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,177 @@
|
|||
import json
|
||||
|
||||
from django.core import checks
|
||||
from django.db.models import NOT_PROVIDED, Field
|
||||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields.tuple_lookups import (
|
||||
TupleExact,
|
||||
TupleGreaterThan,
|
||||
TupleGreaterThanOrEqual,
|
||||
TupleIn,
|
||||
TupleIsNull,
|
||||
TupleLessThan,
|
||||
TupleLessThanOrEqual,
|
||||
)
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class AttributeSetter:
|
||||
def __init__(self, name, value):
|
||||
setattr(self, name, value)
|
||||
|
||||
|
||||
class CompositeAttribute:
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
@property
|
||||
def attnames(self):
|
||||
return [field.attname for field in self.field.fields]
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
return tuple(getattr(instance, attname) for attname in self.attnames)
|
||||
|
||||
def __set__(self, instance, values):
|
||||
attnames = self.attnames
|
||||
length = len(attnames)
|
||||
|
||||
if values is None:
|
||||
values = (None,) * length
|
||||
|
||||
if not isinstance(values, (list, tuple)):
|
||||
raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
|
||||
if length != len(values):
|
||||
raise ValueError(f"{self.field.name!r} must have {length} elements.")
|
||||
|
||||
for attname, value in zip(attnames, values):
|
||||
setattr(instance, attname, value)
|
||||
|
||||
|
||||
class CompositePrimaryKey(Field):
|
||||
descriptor_class = CompositeAttribute
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if (
|
||||
not args
|
||||
or not all(isinstance(field, str) for field in args)
|
||||
or len(set(args)) != len(args)
|
||||
):
|
||||
raise ValueError("CompositePrimaryKey args must be unique strings.")
|
||||
if len(args) == 1:
|
||||
raise ValueError("CompositePrimaryKey must include at least two fields.")
|
||||
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("CompositePrimaryKey cannot have a default.")
|
||||
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("CompositePrimaryKey cannot have a database default.")
|
||||
if kwargs.get("db_column", None) is not None:
|
||||
raise ValueError("CompositePrimaryKey cannot have a db_column.")
|
||||
if kwargs.setdefault("editable", False):
|
||||
raise ValueError("CompositePrimaryKey cannot be editable.")
|
||||
if not kwargs.setdefault("primary_key", True):
|
||||
raise ValueError("CompositePrimaryKey must be a primary key.")
|
||||
if not kwargs.setdefault("blank", True):
|
||||
raise ValueError("CompositePrimaryKey must be blank.")
|
||||
|
||||
self.field_names = args
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
# args is always [] so it can be ignored.
|
||||
name, path, _, kwargs = super().deconstruct()
|
||||
return name, path, self.field_names, kwargs
|
||||
|
||||
@cached_property
|
||||
def fields(self):
|
||||
meta = self.model._meta
|
||||
return tuple(meta.get_field(field_name) for field_name in self.field_names)
|
||||
|
||||
@cached_property
|
||||
def columns(self):
|
||||
return tuple(field.column for field in self.fields)
|
||||
|
||||
def contribute_to_class(self, cls, name, private_only=False):
|
||||
super().contribute_to_class(cls, name, private_only=private_only)
|
||||
cls._meta.pk = self
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def get_attname_column(self):
|
||||
return self.get_attname(), None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.fields)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.field_names)
|
||||
|
||||
@cached_property
|
||||
def cached_col(self):
|
||||
return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)
|
||||
|
||||
def get_col(self, alias, output_field=None):
|
||||
if alias == self.model._meta.db_table and (
|
||||
output_field is None or output_field == self
|
||||
):
|
||||
return self.cached_col
|
||||
|
||||
return ColPairs(alias, self.fields, self.fields, output_field)
|
||||
|
||||
def get_pk_value_on_save(self, instance):
|
||||
values = []
|
||||
|
||||
for field in self.fields:
|
||||
value = field.value_from_object(instance)
|
||||
if value is None:
|
||||
value = field.get_pk_value_on_save(instance)
|
||||
values.append(value)
|
||||
|
||||
return tuple(values)
|
||||
|
||||
def _check_field_name(self):
|
||||
if self.name == "pk":
|
||||
return []
|
||||
return [
|
||||
checks.Error(
|
||||
"'CompositePrimaryKey' must be named 'pk'.",
|
||||
obj=self,
|
||||
id="fields.E013",
|
||||
)
|
||||
]
|
||||
|
||||
def value_to_string(self, obj):
|
||||
values = []
|
||||
vals = self.value_from_object(obj)
|
||||
for field, value in zip(self.fields, vals):
|
||||
obj = AttributeSetter(field.attname, value)
|
||||
values.append(field.value_to_string(obj))
|
||||
return json.dumps(values, ensure_ascii=False)
|
||||
|
||||
def to_python(self, value):
|
||||
if isinstance(value, str):
|
||||
# Assume we're deserializing.
|
||||
vals = json.loads(value)
|
||||
value = [
|
||||
field.to_python(val)
|
||||
for field, val in zip(self.fields, vals, strict=True)
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
CompositePrimaryKey.register_lookup(TupleExact)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThan)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual)
|
||||
CompositePrimaryKey.register_lookup(TupleLessThan)
|
||||
CompositePrimaryKey.register_lookup(TupleLessThanOrEqual)
|
||||
CompositePrimaryKey.register_lookup(TupleIn)
|
||||
CompositePrimaryKey.register_lookup(TupleIsNull)
|
||||
|
||||
|
||||
def unnest(fields):
|
||||
result = []
|
||||
|
||||
for field in fields:
|
||||
if isinstance(field, CompositePrimaryKey):
|
||||
result.extend(field.fields)
|
||||
else:
|
||||
result.append(field)
|
||||
|
||||
return result
|
||||
|
|
@ -0,0 +1,538 @@
|
|||
import datetime
|
||||
import posixpath
|
||||
|
||||
from django import forms
|
||||
from django.core import checks
|
||||
from django.core.exceptions import FieldError
|
||||
from django.core.files.base import ContentFile, File
|
||||
from django.core.files.images import ImageFile
|
||||
from django.core.files.storage import Storage, default_storage
|
||||
from django.core.files.utils import validate_file_name
|
||||
from django.db.models import signals
|
||||
from django.db.models.expressions import DatabaseDefault
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.db.models.utils import AltersData
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.version import PY311
|
||||
|
||||
|
||||
class FieldFile(File, AltersData):
|
||||
def __init__(self, instance, field, name):
|
||||
super().__init__(None, name)
|
||||
self.instance = instance
|
||||
self.field = field
|
||||
self.storage = field.storage
|
||||
self._committed = True
|
||||
|
||||
def __eq__(self, other):
|
||||
# Older code may be expecting FileField values to be simple strings.
|
||||
# By overriding the == operator, it can remain backwards compatibility.
|
||||
if hasattr(other, "name"):
|
||||
return self.name == other.name
|
||||
return self.name == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
# The standard File contains most of the necessary properties, but
|
||||
# FieldFiles can be instantiated without a name, so that needs to
|
||||
# be checked for here.
|
||||
|
||||
def _require_file(self):
|
||||
if not self:
|
||||
raise ValueError(
|
||||
"The '%s' attribute has no file associated with it." % self.field.name
|
||||
)
|
||||
|
||||
def _get_file(self):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self._file = self.storage.open(self.name, "rb")
|
||||
return self._file
|
||||
|
||||
def _set_file(self, file):
|
||||
self._file = file
|
||||
|
||||
def _del_file(self):
|
||||
del self._file
|
||||
|
||||
file = property(_get_file, _set_file, _del_file)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
self._require_file()
|
||||
return self.storage.path(self.name)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
self._require_file()
|
||||
return self.storage.url(self.name)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
self._require_file()
|
||||
if not self._committed:
|
||||
return self.file.size
|
||||
return self.storage.size(self.name)
|
||||
|
||||
def open(self, mode="rb"):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self.file = self.storage.open(self.name, mode)
|
||||
else:
|
||||
self.file.open(mode)
|
||||
return self
|
||||
|
||||
# open() doesn't alter the file's contents, but it does reset the pointer
|
||||
open.alters_data = True
|
||||
|
||||
# In addition to the standard File API, FieldFiles have extra methods
|
||||
# to further manipulate the underlying file, as well as update the
|
||||
# associated model instance.
|
||||
|
||||
def _set_instance_attribute(self, name, content):
|
||||
setattr(self.instance, self.field.attname, name)
|
||||
|
||||
def save(self, name, content, save=True):
|
||||
name = self.field.generate_filename(self.instance, name)
|
||||
self.name = self.storage.save(name, content, max_length=self.field.max_length)
|
||||
self._set_instance_attribute(self.name, content)
|
||||
self._committed = True
|
||||
|
||||
# Save the object because it has changed, unless save is False
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
save.alters_data = True
|
||||
|
||||
def delete(self, save=True):
|
||||
if not self:
|
||||
return
|
||||
# Only close the file if it's already open, which we know by the
|
||||
# presence of self._file
|
||||
if hasattr(self, "_file"):
|
||||
self.close()
|
||||
del self.file
|
||||
|
||||
self.storage.delete(self.name)
|
||||
|
||||
self.name = None
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = False
|
||||
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
delete.alters_data = True
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
file = getattr(self, "_file", None)
|
||||
return file is None or file.closed
|
||||
|
||||
def close(self):
|
||||
file = getattr(self, "_file", None)
|
||||
if file is not None:
|
||||
file.close()
|
||||
|
||||
def __getstate__(self):
|
||||
# FieldFile needs access to its associated model field, an instance and
|
||||
# the file's name. Everything else will be restored later, by
|
||||
# FileDescriptor below.
|
||||
return {
|
||||
"name": self.name,
|
||||
"closed": False,
|
||||
"_committed": True,
|
||||
"_file": None,
|
||||
"instance": self.instance,
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.storage = self.field.storage
|
||||
|
||||
|
||||
class FileDescriptor(DeferredAttribute):
|
||||
"""
|
||||
The descriptor for the file attribute on the model instance. Return a
|
||||
FieldFile when accessed so you can write code like::
|
||||
|
||||
>>> from myapp.models import MyModel
|
||||
>>> instance = MyModel.objects.get(pk=1)
|
||||
>>> instance.file.size
|
||||
|
||||
Assign a file object on assignment so you can do::
|
||||
|
||||
>>> with open('/path/to/hello.world') as f:
|
||||
... instance.file = File(f)
|
||||
"""
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
# This is slightly complicated, so worth an explanation.
|
||||
# instance.file needs to ultimately return some instance of `File`,
|
||||
# probably a subclass. Additionally, this returned object needs to have
|
||||
# the FieldFile API so that users can easily do things like
|
||||
# instance.file.path and have that delegated to the file storage engine.
|
||||
# Easy enough if we're strict about assignment in __set__, but if you
|
||||
# peek below you can see that we're not. So depending on the current
|
||||
# value of the field we have to dynamically construct some sort of
|
||||
# "thing" to return.
|
||||
|
||||
# The instance dict contains whatever was originally assigned
|
||||
# in __set__.
|
||||
file = super().__get__(instance, cls)
|
||||
|
||||
# If this value is a string (instance.file = "path/to/file") or None
|
||||
# then we simply wrap it with the appropriate attribute class according
|
||||
# to the file field. [This is FieldFile for FileFields and
|
||||
# ImageFieldFile for ImageFields; it's also conceivable that user
|
||||
# subclasses might also want to subclass the attribute class]. This
|
||||
# object understands how to convert a path to a file, and also how to
|
||||
# handle None.
|
||||
if isinstance(file, str) or file is None:
|
||||
attr = self.field.attr_class(instance, self.field, file)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# If this value is a DatabaseDefault, initialize the attribute class
|
||||
# for this field with its db_default value.
|
||||
elif isinstance(file, DatabaseDefault):
|
||||
attr = self.field.attr_class(instance, self.field, self.field.db_default)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# Other types of files may be assigned as well, but they need to have
|
||||
# the FieldFile interface added to them. Thus, we wrap any other type of
|
||||
# File inside a FieldFile (well, the field's attr_class, which is
|
||||
# usually FieldFile).
|
||||
elif isinstance(file, File) and not isinstance(file, FieldFile):
|
||||
file_copy = self.field.attr_class(instance, self.field, file.name)
|
||||
file_copy.file = file
|
||||
file_copy._committed = False
|
||||
instance.__dict__[self.field.attname] = file_copy
|
||||
|
||||
# Finally, because of the (some would say boneheaded) way pickle works,
|
||||
# the underlying FieldFile might not actually itself have an associated
|
||||
# file. So we need to reset the details of the FieldFile in those cases.
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
|
||||
file.instance = instance
|
||||
file.field = self.field
|
||||
file.storage = self.field.storage
|
||||
|
||||
# Make sure that the instance is correct.
|
||||
elif isinstance(file, FieldFile) and instance is not file.instance:
|
||||
file.instance = instance
|
||||
|
||||
# That was fun, wasn't it?
|
||||
return instance.__dict__[self.field.attname]
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.__dict__[self.field.attname] = value
|
||||
|
||||
|
||||
class FileField(Field):
|
||||
# The class to wrap instance attributes in. Accessing the file object off
|
||||
# the instance will always return an instance of attr_class.
|
||||
attr_class = FieldFile
|
||||
|
||||
# The descriptor to use for accessing the attribute off of the class.
|
||||
descriptor_class = FileDescriptor
|
||||
|
||||
description = _("File")
|
||||
|
||||
def __init__(
|
||||
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
|
||||
):
|
||||
self._primary_key_set_explicitly = "primary_key" in kwargs
|
||||
|
||||
self.storage = storage or default_storage
|
||||
if callable(self.storage):
|
||||
# Hold a reference to the callable for deconstruct().
|
||||
self._storage_callable = self.storage
|
||||
self.storage = self.storage()
|
||||
if not isinstance(self.storage, Storage):
|
||||
raise TypeError(
|
||||
"%s.storage must be a subclass/instance of %s.%s"
|
||||
% (
|
||||
self.__class__.__qualname__,
|
||||
Storage.__module__,
|
||||
Storage.__qualname__,
|
||||
)
|
||||
)
|
||||
self.upload_to = upload_to
|
||||
|
||||
kwargs.setdefault("max_length", 100)
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_primary_key(),
|
||||
*self._check_upload_to(),
|
||||
]
|
||||
|
||||
def _check_primary_key(self):
|
||||
if self._primary_key_set_explicitly:
|
||||
return [
|
||||
checks.Error(
|
||||
"'primary_key' is not a valid argument for a %s."
|
||||
% self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E201",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _check_upload_to(self):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
|
||||
return [
|
||||
checks.Error(
|
||||
"%s's 'upload_to' argument must be a relative path, not an "
|
||||
"absolute path." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E202",
|
||||
hint="Remove the leading slash.",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if kwargs.get("max_length") == 100:
|
||||
del kwargs["max_length"]
|
||||
kwargs["upload_to"] = self.upload_to
|
||||
storage = getattr(self, "_storage_callable", self.storage)
|
||||
if storage is not default_storage:
|
||||
kwargs["storage"] = storage
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return "FileField"
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
# Need to convert File objects provided via a form to string for
|
||||
# database insertion.
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
file = super().pre_save(model_instance, add)
|
||||
if file.name is None and file._file is not None:
|
||||
exc = FieldError(
|
||||
f"File for {self.name} must have "
|
||||
"the name attribute specified to be saved."
|
||||
)
|
||||
if PY311 and isinstance(file._file, ContentFile):
|
||||
exc.add_note("Pass a 'name' argument to ContentFile.")
|
||||
raise exc
|
||||
|
||||
if file and not file._committed:
|
||||
# Commit the file to storage prior to saving the model
|
||||
file.save(file.name, file.file, save=False)
|
||||
return file
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def generate_filename(self, instance, filename):
|
||||
"""
|
||||
Apply (if callable) or prepend (if a string) upload_to to the filename,
|
||||
then delegate further processing of the name to the storage backend.
|
||||
Until the storage layer, all file paths are expected to be Unix style
|
||||
(with forward slashes).
|
||||
"""
|
||||
if callable(self.upload_to):
|
||||
filename = self.upload_to(instance, filename)
|
||||
else:
|
||||
dirname = datetime.datetime.now().strftime(str(self.upload_to))
|
||||
filename = posixpath.join(dirname, filename)
|
||||
filename = validate_file_name(filename, allow_relative_path=True)
|
||||
return self.storage.generate_filename(filename)
|
||||
|
||||
def save_form_data(self, instance, data):
|
||||
# Important: None means "no change", other false value means "clear"
|
||||
# This subtle distinction (rather than a more explicit marker) is
|
||||
# needed because we need to consume values that are also sane for a
|
||||
# regular (non Model-) Form to find in its cleaned_data dictionary.
|
||||
if data is not None:
|
||||
# This value will be converted to str and stored in the
|
||||
# database, so leaving False as-is is not acceptable.
|
||||
setattr(instance, self.name, data or "")
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.FileField,
|
||||
"max_length": self.max_length,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ImageFileDescriptor(FileDescriptor):
|
||||
"""
|
||||
Just like the FileDescriptor, but for ImageFields. The only difference is
|
||||
assigning the width/height to the width_field/height_field, if appropriate.
|
||||
"""
|
||||
|
||||
def __set__(self, instance, value):
|
||||
previous_file = instance.__dict__.get(self.field.attname)
|
||||
super().__set__(instance, value)
|
||||
|
||||
# To prevent recalculating image dimensions when we are instantiating
|
||||
# an object from the database (bug #11084), only update dimensions if
|
||||
# the field had a value before this assignment. Since the default
|
||||
# value for FileField subclasses is an instance of field.attr_class,
|
||||
# previous_file will only be None when we are called from
|
||||
# Model.__init__(). The ImageField.update_dimension_fields method
|
||||
# hooked up to the post_init signal handles the Model.__init__() cases.
|
||||
# Assignment happening outside of Model.__init__() will trigger the
|
||||
# update right here.
|
||||
if previous_file is not None:
|
||||
self.field.update_dimension_fields(instance, force=True)
|
||||
|
||||
|
||||
class ImageFieldFile(ImageFile, FieldFile):
|
||||
def _set_instance_attribute(self, name, content):
|
||||
setattr(self.instance, self.field.attname, content)
|
||||
# Update the name in case generate_filename() or storage.save() changed
|
||||
# it, but bypass the descriptor to avoid re-reading the file.
|
||||
self.instance.__dict__[self.field.attname] = self.name
|
||||
|
||||
def delete(self, save=True):
|
||||
# Clear the image dimensions cache
|
||||
if hasattr(self, "_dimensions_cache"):
|
||||
del self._dimensions_cache
|
||||
super().delete(save)
|
||||
|
||||
|
||||
class ImageField(FileField):
|
||||
attr_class = ImageFieldFile
|
||||
descriptor_class = ImageFileDescriptor
|
||||
description = _("Image")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
width_field=None,
|
||||
height_field=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.width_field, self.height_field = width_field, height_field
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_image_library_installed(),
|
||||
]
|
||||
|
||||
def _check_image_library_installed(self):
|
||||
try:
|
||||
from PIL import Image # NOQA
|
||||
except ImportError:
|
||||
return [
|
||||
checks.Error(
|
||||
"Cannot use ImageField because Pillow is not installed.",
|
||||
hint=(
|
||||
"Get Pillow at https://pypi.org/project/Pillow/ "
|
||||
'or run command "python -m pip install Pillow".'
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E210",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.width_field:
|
||||
kwargs["width_field"] = self.width_field
|
||||
if self.height_field:
|
||||
kwargs["height_field"] = self.height_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
# Attach update_dimension_fields so that dimension fields declared
|
||||
# after their corresponding image field don't stay cleared by
|
||||
# Model.__init__, see bug #11196.
|
||||
# Only run post-initialization dimension update on non-abstract models
|
||||
# with width_field/height_field.
|
||||
if not cls._meta.abstract and (self.width_field or self.height_field):
|
||||
signals.post_init.connect(self.update_dimension_fields, sender=cls)
|
||||
|
||||
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
|
||||
"""
|
||||
Update field's width and height fields, if defined.
|
||||
|
||||
This method is hooked up to model's post_init signal to update
|
||||
dimensions after instantiating a model instance. However, dimensions
|
||||
won't be updated if the dimensions fields are already populated. This
|
||||
avoids unnecessary recalculation when loading an object from the
|
||||
database.
|
||||
|
||||
Dimensions can be forced to update with force=True, which is how
|
||||
ImageFileDescriptor.__set__ calls this method.
|
||||
"""
|
||||
# Nothing to update if the field doesn't have dimension fields or if
|
||||
# the field is deferred.
|
||||
has_dimension_fields = self.width_field or self.height_field
|
||||
if not has_dimension_fields or self.attname not in instance.__dict__:
|
||||
return
|
||||
|
||||
# getattr will call the ImageFileDescriptor's __get__ method, which
|
||||
# coerces the assigned value into an instance of self.attr_class
|
||||
# (ImageFieldFile in this case).
|
||||
file = getattr(instance, self.attname)
|
||||
|
||||
# Nothing to update if we have no file and not being forced to update.
|
||||
if not file and not force:
|
||||
return
|
||||
|
||||
dimension_fields_filled = not (
|
||||
(self.width_field and not getattr(instance, self.width_field))
|
||||
or (self.height_field and not getattr(instance, self.height_field))
|
||||
)
|
||||
# When both dimension fields have values, we are most likely loading
|
||||
# data from the database or updating an image field that already had
|
||||
# an image stored. In the first case, we don't want to update the
|
||||
# dimension fields because we are already getting their values from the
|
||||
# database. In the second case, we do want to update the dimensions
|
||||
# fields and will skip this return because force will be True since we
|
||||
# were called from ImageFileDescriptor.__set__.
|
||||
if dimension_fields_filled and not force:
|
||||
return
|
||||
|
||||
# file should be an instance of ImageFieldFile or should be None.
|
||||
if file:
|
||||
width = file.width
|
||||
height = file.height
|
||||
else:
|
||||
# No file, so clear dimensions fields.
|
||||
width = None
|
||||
height = None
|
||||
|
||||
# Update the width and height fields.
|
||||
if self.width_field:
|
||||
setattr(instance, self.width_field, width)
|
||||
if self.height_field:
|
||||
setattr(instance, self.height_field, height)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.ImageField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
from django.core import checks
|
||||
from django.db import connections, router
|
||||
from django.db.models.sql import Query
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from . import NOT_PROVIDED, Field
|
||||
|
||||
__all__ = ["GeneratedField"]
|
||||
|
||||
|
||||
class GeneratedField(Field):
|
||||
generated = True
|
||||
db_returning = True
|
||||
|
||||
_query = None
|
||||
output_field = None
|
||||
|
||||
def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
|
||||
if kwargs.setdefault("editable", False):
|
||||
raise ValueError("GeneratedField cannot be editable.")
|
||||
if not kwargs.setdefault("blank", True):
|
||||
raise ValueError("GeneratedField must be blank.")
|
||||
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("GeneratedField cannot have a default.")
|
||||
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("GeneratedField cannot have a database default.")
|
||||
if db_persist not in (True, False):
|
||||
raise ValueError("GeneratedField.db_persist must be True or False.")
|
||||
|
||||
self.expression = expression
|
||||
self.output_field = output_field
|
||||
self.db_persist = db_persist
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@cached_property
|
||||
def cached_col(self):
|
||||
from django.db.models.expressions import Col
|
||||
|
||||
return Col(self.model._meta.db_table, self, self.output_field)
|
||||
|
||||
def get_col(self, alias, output_field=None):
|
||||
if alias != self.model._meta.db_table and output_field in (None, self):
|
||||
output_field = self.output_field
|
||||
return super().get_col(alias, output_field)
|
||||
|
||||
def contribute_to_class(self, *args, **kwargs):
|
||||
super().contribute_to_class(*args, **kwargs)
|
||||
|
||||
self._query = Query(model=self.model, alias_cols=False)
|
||||
# Register lookups from the output_field class.
|
||||
for lookup_name, lookup in self.output_field.get_class_lookups().items():
|
||||
self.register_lookup(lookup, lookup_name=lookup_name)
|
||||
|
||||
def generated_sql(self, connection):
|
||||
compiler = connection.ops.compiler("SQLCompiler")(
|
||||
self._query, connection=connection, using=None
|
||||
)
|
||||
resolved_expression = self.expression.resolve_expression(
|
||||
self._query, allow_joins=False
|
||||
)
|
||||
sql, params = compiler.compile(resolved_expression)
|
||||
if (
|
||||
getattr(self.expression, "conditional", False)
|
||||
and not connection.features.supports_boolean_expr_in_select_clause
|
||||
):
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
def check(self, **kwargs):
|
||||
databases = kwargs.get("databases") or []
|
||||
errors = [
|
||||
*super().check(**kwargs),
|
||||
*self._check_supported(databases),
|
||||
*self._check_persistence(databases),
|
||||
]
|
||||
output_field_clone = self.output_field.clone()
|
||||
output_field_clone.model = self.model
|
||||
output_field_checks = output_field_clone.check(databases=databases)
|
||||
if output_field_checks:
|
||||
separator = "\n "
|
||||
error_messages = separator.join(
|
||||
f"{output_check.msg} ({output_check.id})"
|
||||
for output_check in output_field_checks
|
||||
if isinstance(output_check, checks.Error)
|
||||
)
|
||||
if error_messages:
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"GeneratedField.output_field has errors:"
|
||||
f"{separator}{error_messages}",
|
||||
obj=self,
|
||||
id="fields.E223",
|
||||
)
|
||||
)
|
||||
warning_messages = separator.join(
|
||||
f"{output_check.msg} ({output_check.id})"
|
||||
for output_check in output_field_checks
|
||||
if isinstance(output_check, checks.Warning)
|
||||
)
|
||||
if warning_messages:
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
"GeneratedField.output_field has warnings:"
|
||||
f"{separator}{warning_messages}",
|
||||
obj=self,
|
||||
id="fields.W224",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
connection.features.supports_virtual_generated_columns
|
||||
or "supports_stored_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
) and not (
|
||||
connection.features.supports_stored_generated_columns
|
||||
or "supports_virtual_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{connection.display_name} does not support GeneratedFields.",
|
||||
obj=self,
|
||||
id="fields.E220",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _check_persistence(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not self.db_persist and not (
|
||||
connection.features.supports_virtual_generated_columns
|
||||
or "supports_virtual_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{connection.display_name} does not support non-persisted "
|
||||
"GeneratedFields.",
|
||||
obj=self,
|
||||
id="fields.E221",
|
||||
hint="Set db_persist=True on the field.",
|
||||
)
|
||||
)
|
||||
if self.db_persist and not (
|
||||
connection.features.supports_stored_generated_columns
|
||||
or "supports_stored_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{connection.display_name} does not support persisted "
|
||||
"GeneratedFields.",
|
||||
obj=self,
|
||||
id="fields.E222",
|
||||
hint="Set db_persist=False on the field.",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
del kwargs["blank"]
|
||||
del kwargs["editable"]
|
||||
kwargs["db_persist"] = self.db_persist
|
||||
kwargs["expression"] = self.expression
|
||||
kwargs["output_field"] = self.output_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.output_field.get_internal_type()
|
||||
|
||||
def db_parameters(self, connection):
|
||||
return self.output_field.db_parameters(connection)
|
||||
|
||||
def db_type_parameters(self, connection):
|
||||
return self.output_field.db_type_parameters(connection)
|
||||
|
|
@ -0,0 +1,664 @@
|
|||
import json
|
||||
|
||||
from django import forms
|
||||
from django.core import checks, exceptions
|
||||
from django.db import NotSupportedError, connections, router
|
||||
from django.db.models import expressions, lookups
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.lookups import (
|
||||
FieldGetDbPrepValueMixin,
|
||||
PostgresOperatorLookup,
|
||||
Transform,
|
||||
)
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("A JSON object")
|
||||
default_error_messages = {
|
||||
"invalid": _("Value must be valid JSON."),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError("The encoder parameter must be a callable object.")
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError("The decoder parameter must be a callable object.")
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get("databases") or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
"supports_json_field" in self.model._meta.required_db_features
|
||||
or connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"%s does not support JSONFields." % connection.display_name,
|
||||
obj=self.model,
|
||||
id="fields.E180",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs["encoder"] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs["decoder"] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
# Some backends (SQLite at least) extract non-string values in their
|
||||
# SQL datatypes.
|
||||
if isinstance(expression, KeyTransform) and not isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "JSONField"
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not prepared:
|
||||
value = self.get_prep_value(value)
|
||||
return connection.ops.adapt_json_value(value, self.encoder)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
# This slightly involved logic is to allow for `None` to be used to
|
||||
# store SQL `NULL` while `Value(None, JSONField())` can be used to
|
||||
# store JSON `null` while preventing compilable `as_sql` values from
|
||||
# making their way to `get_db_prep_value`, which is what the `super()`
|
||||
# implementation does.
|
||||
if value is None:
|
||||
return value
|
||||
if (
|
||||
isinstance(value, expressions.Value)
|
||||
and value.value is None
|
||||
and isinstance(value.output_field, JSONField)
|
||||
):
|
||||
value = None
|
||||
return super().get_db_prep_save(value, connection)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
try:
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["invalid"],
|
||||
code="invalid",
|
||||
params={"value": value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.JSONField,
|
||||
"encoder": self.encoder,
|
||||
"decoder": self.decoder,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def compile_json_path(key_transforms, include_root=True):
|
||||
path = ["$"] if include_root else []
|
||||
for key_transform in key_transforms:
|
||||
try:
|
||||
num = int(key_transform)
|
||||
except ValueError: # non-integer
|
||||
path.append(".")
|
||||
path.append(json.dumps(key_transform))
|
||||
else:
|
||||
path.append("[%s]" % num)
|
||||
return "".join(path)
|
||||
|
||||
|
||||
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contains"
|
||||
postgres_operator = "@>"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contains lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
postgres_operator = "<@"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contained_by lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
logical_operator = None
|
||||
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
# Compile the final key without interpreting ints as array elements.
|
||||
return ".%s" % json.dumps(key_transform)
|
||||
|
||||
def _as_sql_parts(self, compiler, connection):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs_sql, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
|
||||
compiler, connection
|
||||
)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs_sql, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = "$"
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
if not isinstance(rhs, (list, tuple)):
|
||||
rhs = [rhs]
|
||||
for key in rhs:
|
||||
if isinstance(key, KeyTransform):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
*rhs_key_transforms, final_key = rhs_key_transforms
|
||||
rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
|
||||
rhs_json_path += self.compile_json_path_final_key(final_key)
|
||||
yield lhs_sql, lhs_params, lhs_json_path + rhs_json_path
|
||||
|
||||
def _combine_sql_parts(self, parts):
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
return "(%s)" % self.logical_operator.join(parts)
|
||||
return "".join(parts)
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
sql_parts = []
|
||||
params = []
|
||||
for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
|
||||
compiler, connection
|
||||
):
|
||||
sql_parts.append(template % (lhs_sql, "%s"))
|
||||
params.extend(lhs_params + [rhs_json_path])
|
||||
return self._combine_sql_parts(sql_parts), tuple(params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %s)"
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Use a custom delimiter to prevent the JSON path from escaping the SQL
|
||||
# literal. See comment in KeyTransform.
|
||||
template = "JSON_EXISTS(%s, q'\uffff%s\uffff')"
|
||||
sql_parts = []
|
||||
params = []
|
||||
for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
|
||||
compiler, connection
|
||||
):
|
||||
# Add right-hand-side directly into SQL because it cannot be passed
|
||||
# as bind variables to JSON_EXISTS. It might result in invalid
|
||||
# queries but it is assumed that it cannot be evaded because the
|
||||
# path is JSON serialized.
|
||||
sql_parts.append(template % (lhs_sql, rhs_json_path))
|
||||
params.extend(lhs_params)
|
||||
return self._combine_sql_parts(sql_parts), tuple(params)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
|
||||
for key in rhs_key_transforms[:-1]:
|
||||
self.lhs = KeyTransform(key, self.lhs)
|
||||
self.rhs = rhs_key_transforms[-1]
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_TYPE(%s, %s) IS NOT NULL"
|
||||
)
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = "has_key"
|
||||
postgres_operator = "?"
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = "has_keys"
|
||||
postgres_operator = "?&"
|
||||
logical_operator = " AND "
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = "has_any_keys"
|
||||
postgres_operator = "?|"
|
||||
logical_operator = " OR "
|
||||
|
||||
|
||||
class HasKeyOrArrayIndex(HasKey):
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
return compile_json_path([key_transform], include_root=False)
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == "%s" and rhs_params == [None]:
|
||||
rhs_params = ["null"]
|
||||
if connection.vendor == "mysql":
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
lhs = f"JSON({lhs})"
|
||||
rhs = f"JSON({rhs})"
|
||||
return f"JSON_EQUAL({lhs}, {rhs} ERROR ON ERROR)", (*lhs_params, *rhs_params)
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = str(key_name)
|
||||
|
||||
def preprocess_lhs(self, compiler, connection):
|
||||
key_transforms = [self.key_name]
|
||||
previous = self.lhs
|
||||
while isinstance(previous, KeyTransform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == "oracle":
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace("%", "%%") for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
|
||||
")"
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff')"
|
||||
")"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle. Use a custom delimiter to prevent the
|
||||
# JSON path from escaping the SQL literal. Each key in the JSON path is
|
||||
# passed through json.dumps() with ensure_ascii=True (the default),
|
||||
# which converts the delimiter into the escaped \uffff format. This
|
||||
# ensures that the delimiter is not present in the JSON path.
|
||||
return sql % ((lhs, json_path) * 2), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = "->>"
|
||||
postgres_nested_operator = "#>>"
|
||||
output_field = TextField()
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
if connection.mysql_is_mariadb:
|
||||
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
|
||||
sql, params = super().as_mysql(compiler, connection)
|
||||
return "JSON_UNQUOTE(%s)" % sql, params
|
||||
else:
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
@classmethod
|
||||
def from_lookup(cls, lookup):
|
||||
transform, *keys = lookup.split(LOOKUP_SEP)
|
||||
if not keys:
|
||||
raise ValueError("Lookup must contain key or index transforms.")
|
||||
for key in keys:
|
||||
transform = cls(key, transform)
|
||||
return transform
|
||||
|
||||
|
||||
KT = KeyTextTransform.from_lookup
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
"""
|
||||
Mixin for combining with a lookup expecting a text lhs from a JSONField
|
||||
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
"Transform should be an instance of KeyTransform in order to "
|
||||
"use this lookup."
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name,
|
||||
*key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = HasKeyOrArrayIndex(
|
||||
self.lhs.lhs,
|
||||
self.lhs.key_name,
|
||||
).as_oracle(compiler, connection)
|
||||
if not self.rhs:
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = "JSON_TYPE(%s, %s) IS NULL"
|
||||
if not self.rhs:
|
||||
template = "JSON_TYPE(%s, %s) IS NOT NULL"
|
||||
return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, "as_sql")
|
||||
and not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql %= "JSON_QUERY"
|
||||
else:
|
||||
sql %= "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
return sql, params
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
def process_rhs(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "oracle":
|
||||
func = []
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
for value in rhs_params:
|
||||
value = json.loads(value)
|
||||
if isinstance(value, (list, dict)):
|
||||
func.append(sql % "JSON_QUERY")
|
||||
else:
|
||||
func.append(sql % "JSON_VALUE")
|
||||
rhs %= tuple(func)
|
||||
elif connection.vendor == "sqlite":
|
||||
func = []
|
||||
for value in rhs_params:
|
||||
if value in connection.ops.jsonfield_datatype_values:
|
||||
func.append("%s")
|
||||
else:
|
||||
func.append("JSON_EXTRACT(%s, '$')")
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs_params == ["null"]:
|
||||
# Field has key and it's NULL.
|
||||
has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
|
||||
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
|
||||
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
|
||||
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
|
||||
return (
|
||||
"%s AND %s" % (has_key_sql, is_null_sql),
|
||||
tuple(has_key_params) + tuple(is_null_params),
|
||||
)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class KeyTransformIExact(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIContains(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformNumericLookupMixin:
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if not connection.features.has_native_json_field:
|
||||
rhs_params = [json.loads(value) for value in rhs_params]
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformIn)
|
||||
KeyTransform.register_lookup(KeyTransformExact)
|
||||
KeyTransform.register_lookup(KeyTransformIExact)
|
||||
KeyTransform.register_lookup(KeyTransformIsNull)
|
||||
KeyTransform.register_lookup(KeyTransformIContains)
|
||||
KeyTransform.register_lookup(KeyTransformStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformRegex)
|
||||
KeyTransform.register_lookup(KeyTransformIRegex)
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformLt)
|
||||
KeyTransform.register_lookup(KeyTransformLte)
|
||||
KeyTransform.register_lookup(KeyTransformGt)
|
||||
KeyTransform.register_lookup(KeyTransformGte)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return KeyTransform(self.key_name, *args, **kwargs)
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
import warnings
|
||||
|
||||
from django.core import checks
|
||||
from django.utils.deprecation import RemovedInDjango60Warning
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
NOT_PROVIDED = object()
|
||||
|
||||
|
||||
class FieldCacheMixin:
|
||||
"""
|
||||
An API for working with the model's fields value cache.
|
||||
|
||||
Subclasses must set self.cache_name to a unique entry for the cache -
|
||||
typically the field’s name.
|
||||
"""
|
||||
|
||||
# RemovedInDjango60Warning.
|
||||
def get_cache_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@cached_property
|
||||
def cache_name(self):
|
||||
# RemovedInDjango60Warning: when the deprecation ends, replace with:
|
||||
# raise NotImplementedError
|
||||
cache_name = self.get_cache_name()
|
||||
warnings.warn(
|
||||
f"Override {self.__class__.__qualname__}.cache_name instead of "
|
||||
"get_cache_name().",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return cache_name
|
||||
|
||||
def get_cached_value(self, instance, default=NOT_PROVIDED):
|
||||
try:
|
||||
return instance._state.fields_cache[self.cache_name]
|
||||
except KeyError:
|
||||
if default is NOT_PROVIDED:
|
||||
raise
|
||||
return default
|
||||
|
||||
def is_cached(self, instance):
|
||||
return self.cache_name in instance._state.fields_cache
|
||||
|
||||
def set_cached_value(self, instance, value):
|
||||
instance._state.fields_cache[self.cache_name] = value
|
||||
|
||||
def delete_cached_value(self, instance):
|
||||
del instance._state.fields_cache[self.cache_name]
|
||||
|
||||
|
||||
class CheckFieldDefaultMixin:
|
||||
_default_hint = ("<valid default>", "<invalid default>")
|
||||
|
||||
def _check_default(self):
|
||||
if (
|
||||
self.has_default()
|
||||
and self.default is not None
|
||||
and not callable(self.default)
|
||||
):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s default should be a callable instead of an instance "
|
||||
"so that it's not shared between all field instances."
|
||||
% (self.__class__.__name__,),
|
||||
hint=(
|
||||
"Use a callable instead, e.g., use `%s` instead of "
|
||||
"`%s`." % self._default_hint
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E010",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
errors.extend(self._check_default())
|
||||
return errors
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""
|
||||
Field-like classes that aren't really fields. It's easier to use objects that
|
||||
have the same attributes as fields sometimes (avoids a lot of special casing).
|
||||
"""
|
||||
|
||||
from django.db.models import fields
|
||||
|
||||
|
||||
class OrderWrt(fields.IntegerField):
|
||||
"""
|
||||
A proxy for the _order database field that is used when
|
||||
Meta.order_with_respect_to is specified.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["name"] = "_order"
|
||||
kwargs["editable"] = False
|
||||
super().__init__(*args, **kwargs)
|
||||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,151 @@
|
|||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields import composite
|
||||
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
|
||||
if isinstance(value, Model):
|
||||
if not value._is_pk_set():
|
||||
raise ValueError("Model instances passed to related filters must be saved.")
|
||||
value_list = []
|
||||
sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields)
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(
|
||||
source.remote_field.field_name
|
||||
)
|
||||
try:
|
||||
value_list.append(getattr(value, source.attname))
|
||||
except AttributeError:
|
||||
# A case like Restaurant.objects.filter(place=restaurant_instance),
|
||||
# where place is a OneToOneField and the primary key of Restaurant.
|
||||
pk = value.pk
|
||||
return pk if isinstance(pk, tuple) else (pk,)
|
||||
return tuple(value_list)
|
||||
if not isinstance(value, tuple):
|
||||
return (value,)
|
||||
return value
|
||||
|
||||
|
||||
class RelatedIn(In):
|
||||
def get_prep_lookup(self):
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if (
|
||||
isinstance(self.rhs, Query)
|
||||
and not self.rhs.has_select_fields
|
||||
and self.lhs.output_field.related_model is self.rhs.model
|
||||
):
|
||||
self.rhs.set_values([f.name for f in self.lhs.sources])
|
||||
else:
|
||||
if self.rhs_is_direct_value():
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
|
||||
# We need to run the related field's get_prep_value(). Consider
|
||||
# case ForeignKey to IntegerField given value 'abc'. The
|
||||
# ForeignKey itself doesn't have validation for non-integers,
|
||||
# so we must run validation using the target field.
|
||||
if hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Run the target field's get_prep_value. We can safely
|
||||
# assume there is only one as we don't get to the direct
|
||||
# value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[
|
||||
-1
|
||||
]
|
||||
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
|
||||
elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
|
||||
self.lhs.field.target_field, "primary_key", False
|
||||
):
|
||||
if (
|
||||
getattr(self.lhs.output_field, "primary_key", False)
|
||||
and self.lhs.output_field.model == self.rhs.model
|
||||
):
|
||||
# A case like
|
||||
# Restaurant.objects.filter(place__in=restaurant_qs), where
|
||||
# place is a OneToOneField and the primary key of
|
||||
# Restaurant.
|
||||
target_field = self.lhs.field.name
|
||||
else:
|
||||
target_field = self.lhs.field.target_field.name
|
||||
self.rhs.set_values([target_field])
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if self.rhs_is_direct_value():
|
||||
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
|
||||
lookup = TupleIn(self.lhs, values)
|
||||
else:
|
||||
lookup = TupleIn(self.lhs, self.rhs)
|
||||
return compiler.compile(lookup)
|
||||
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, ColPairs) and not hasattr(
|
||||
self.rhs, "resolve_expression"
|
||||
):
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
|
||||
# We need to run the related field's get_prep_value(). Consider case
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Get the target field. We can safely assume there is only one
|
||||
# as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
|
||||
self.rhs = target_field.get_prep_value(self.rhs)
|
||||
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if not self.rhs_is_direct_value():
|
||||
raise ValueError(
|
||||
f"'{self.lookup_name}' doesn't support multi-column subqueries."
|
||||
)
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
lookup_class = tuple_lookups[self.lookup_name]
|
||||
lookup = lookup_class(self.lhs, self.rhs)
|
||||
return compiler.compile(lookup)
|
||||
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedExact(RelatedLookupMixin, Exact):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThan(RelatedLookupMixin, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedIsNull(RelatedLookupMixin, IsNull):
|
||||
pass
|
||||
|
|
@ -0,0 +1,416 @@
|
|||
"""
|
||||
"Rel objects" for related fields.
|
||||
|
||||
"Rel objects" (for lack of a better name) carry information about the relation
|
||||
modeled by a related field and provide some utility functions. They're stored
|
||||
in the ``remote_field`` attribute of the field.
|
||||
|
||||
They also act as reverse fields for the purposes of the Meta API because
|
||||
they're the closest concept currently available.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from django.core import exceptions
|
||||
from django.utils.deprecation import RemovedInDjango60Warning
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
from . import BLANK_CHOICE_DASH
|
||||
from .mixins import FieldCacheMixin
|
||||
|
||||
|
||||
class ForeignObjectRel(FieldCacheMixin):
|
||||
"""
|
||||
Used by ForeignObject to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
# Field flags
|
||||
auto_created = True
|
||||
concrete = False
|
||||
editable = False
|
||||
is_relation = True
|
||||
|
||||
# Reverse relations are always nullable (Django can't enforce that a
|
||||
# foreign key on the related model points to this model).
|
||||
null = True
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
self.field = field
|
||||
self.model = to
|
||||
self.related_name = related_name
|
||||
self.related_query_name = related_query_name
|
||||
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
|
||||
self.parent_link = parent_link
|
||||
self.on_delete = on_delete
|
||||
|
||||
self.symmetrical = False
|
||||
self.multiple = True
|
||||
|
||||
# Some of the following cached_properties can't be initialized in
|
||||
# __init__ as the field doesn't have its model yet. Calling these methods
|
||||
# before field.contribute_to_class() has been called will result in
|
||||
# AttributeError
|
||||
@cached_property
|
||||
def hidden(self):
|
||||
"""Should the related object be hidden?"""
|
||||
return bool(self.related_name) and self.related_name[-1] == "+"
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self.field.related_query_name()
|
||||
|
||||
@property
|
||||
def remote_field(self):
|
||||
return self.field
|
||||
|
||||
@property
|
||||
def target_field(self):
|
||||
"""
|
||||
When filtering against this relation, return the field on the remote
|
||||
model against which the filtering should happen.
|
||||
"""
|
||||
target_fields = self.path_infos[-1].target_fields
|
||||
if len(target_fields) > 1:
|
||||
raise exceptions.FieldError(
|
||||
"Can't use target_field for multicolumn relations."
|
||||
)
|
||||
return target_fields[0]
|
||||
|
||||
@cached_property
|
||||
def related_model(self):
|
||||
if not self.field.model:
|
||||
raise AttributeError(
|
||||
"This property can't be accessed before self.field.contribute_to_class "
|
||||
"has been called."
|
||||
)
|
||||
return self.field.model
|
||||
|
||||
@cached_property
|
||||
def many_to_many(self):
|
||||
return self.field.many_to_many
|
||||
|
||||
@cached_property
|
||||
def many_to_one(self):
|
||||
return self.field.one_to_many
|
||||
|
||||
@cached_property
|
||||
def one_to_many(self):
|
||||
return self.field.many_to_one
|
||||
|
||||
@cached_property
|
||||
def one_to_one(self):
|
||||
return self.field.one_to_one
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
return self.field.get_lookup(lookup_name)
|
||||
|
||||
def get_lookups(self):
|
||||
return self.field.get_lookups()
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.field.get_transform(name)
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.field.get_internal_type()
|
||||
|
||||
@property
|
||||
def db_type(self):
|
||||
return self.field.db_type
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s.%s>" % (
|
||||
type(self).__name__,
|
||||
self.related_model._meta.app_label,
|
||||
self.related_model._meta.model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.field,
|
||||
self.model,
|
||||
self.related_name,
|
||||
self.related_query_name,
|
||||
make_hashable(self.limit_choices_to),
|
||||
self.parent_link,
|
||||
self.on_delete,
|
||||
self.symmetrical,
|
||||
self.multiple,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
# Delete the path_infos cached property because it can be recalculated
|
||||
# at first invocation after deserialization. The attribute must be
|
||||
# removed because subclasses like ManyToOneRel may have a PathInfo
|
||||
# which contains an intermediate M2M table that's been dynamically
|
||||
# created and doesn't exist in the .models module.
|
||||
# This is a reverse relation, so there is no reverse_path_infos to
|
||||
# delete.
|
||||
state.pop("path_infos", None)
|
||||
return state
|
||||
|
||||
def get_choices(
|
||||
self,
|
||||
include_blank=True,
|
||||
blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None,
|
||||
ordering=(),
|
||||
):
|
||||
"""
|
||||
Return choices with a default blank choices included, for use
|
||||
as <select> choices for this field.
|
||||
|
||||
Analog of django.db.models.fields.Field.get_choices(), provided
|
||||
initially for utilization by RelatedFieldListFilter.
|
||||
"""
|
||||
limit_choices_to = limit_choices_to or self.limit_choices_to
|
||||
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
|
||||
if ordering:
|
||||
qs = qs.order_by(*ordering)
|
||||
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
|
||||
|
||||
def get_joining_columns(self):
|
||||
warnings.warn(
|
||||
"ForeignObjectRel.get_joining_columns() is deprecated. Use "
|
||||
"get_joining_fields() instead.",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.field.get_reverse_joining_columns()
|
||||
|
||||
def get_joining_fields(self):
|
||||
return self.field.get_reverse_joining_fields()
|
||||
|
||||
def get_extra_restriction(self, alias, related_alias):
|
||||
return self.field.get_extra_restriction(related_alias, alias)
|
||||
|
||||
def set_field_name(self):
|
||||
"""
|
||||
Set the related field's name, this is not available until later stages
|
||||
of app loading, so set_field_name is called from
|
||||
set_attributes_from_rel()
|
||||
"""
|
||||
# By default foreign object doesn't relate to any remote field (for
|
||||
# example custom multicolumn joins currently have no remote field).
|
||||
self.field_name = None
|
||||
|
||||
@cached_property
|
||||
def accessor_name(self):
|
||||
return self.get_accessor_name()
|
||||
|
||||
def get_accessor_name(self, model=None):
|
||||
# This method encapsulates the logic that decides what name to give an
|
||||
# accessor descriptor that retrieves related many-to-one or
|
||||
# many-to-many objects. It uses the lowercased object_name + "_set",
|
||||
# but this can be overridden with the "related_name" option. Due to
|
||||
# backwards compatibility ModelForms need to be able to provide an
|
||||
# alternate model. See BaseInlineFormSet.get_default_prefix().
|
||||
opts = model._meta if model else self.related_model._meta
|
||||
model = model or self.related_model
|
||||
if self.multiple:
|
||||
# If this is a symmetrical m2m relation on self, there is no
|
||||
# reverse accessor.
|
||||
if self.symmetrical and model == self.model:
|
||||
return None
|
||||
if self.related_name:
|
||||
return self.related_name
|
||||
return opts.model_name + ("_set" if self.multiple else "")
|
||||
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
if filtered_relation:
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
else:
|
||||
return self.field.reverse_path_infos
|
||||
|
||||
@cached_property
|
||||
def path_infos(self):
|
||||
return self.get_path_info()
|
||||
|
||||
@cached_property
|
||||
def cache_name(self):
|
||||
"""
|
||||
Return the name of the cache key to use for storing an instance of the
|
||||
forward model on the reverse model.
|
||||
"""
|
||||
return self.accessor_name
|
||||
|
||||
|
||||
class ManyToOneRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by the ForeignKey field to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
|
||||
Note: Because we somewhat abuse the Rel objects by using them as reverse
|
||||
fields we get the funny situation where
|
||||
``ManyToOneRel.many_to_one == False`` and
|
||||
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
|
||||
ManyToOneRel class is a private API and there is work underway to turn
|
||||
reverse relations into actual fields.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.field_name = field_name
|
||||
|
||||
def __getstate__(self):
|
||||
state = super().__getstate__()
|
||||
state.pop("related_model", None)
|
||||
return state
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (self.field_name,)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the Field in the 'to' object to which this relationship is tied.
|
||||
"""
|
||||
field = self.model._meta.get_field(self.field_name)
|
||||
if not field.concrete:
|
||||
raise exceptions.FieldDoesNotExist(
|
||||
"No related field named '%s'" % self.field_name
|
||||
)
|
||||
return field
|
||||
|
||||
def set_field_name(self):
|
||||
self.field_name = self.field_name or self.model._meta.pk.name
|
||||
|
||||
|
||||
class OneToOneRel(ManyToOneRel):
|
||||
"""
|
||||
Used by OneToOneField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.multiple = False
|
||||
|
||||
|
||||
class ManyToManyRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by ManyToManyField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
symmetrical=True,
|
||||
through=None,
|
||||
through_fields=None,
|
||||
db_constraint=True,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
)
|
||||
|
||||
if through and not db_constraint:
|
||||
raise ValueError("Can't supply a through model and db_constraint=False")
|
||||
self.through = through
|
||||
|
||||
if through_fields and not through:
|
||||
raise ValueError("Cannot specify through_fields without a through model")
|
||||
self.through_fields = through_fields
|
||||
|
||||
self.symmetrical = symmetrical
|
||||
self.db_constraint = db_constraint
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (
|
||||
self.through,
|
||||
make_hashable(self.through_fields),
|
||||
self.db_constraint,
|
||||
)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the field in the 'to' object to which this relationship is tied.
|
||||
Provided for symmetry with ManyToOneRel.
|
||||
"""
|
||||
opts = self.through._meta
|
||||
if self.through_fields:
|
||||
field = opts.get_field(self.through_fields[0])
|
||||
else:
|
||||
for field in opts.fields:
|
||||
rel = getattr(field, "remote_field", None)
|
||||
if rel and rel.model == self.model:
|
||||
break
|
||||
return field.foreign_related_fields[0]
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
import itertools
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models import Field
|
||||
from django.db.models.expressions import (
|
||||
ColPairs,
|
||||
Func,
|
||||
ResolvedOuterRef,
|
||||
Subquery,
|
||||
Value,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
from django.db.models.sql import Query
|
||||
from django.db.models.sql.where import AND, OR, WhereNode
|
||||
|
||||
|
||||
class Tuple(Func):
|
||||
allows_composite_expressions = True
|
||||
function = ""
|
||||
output_field = Field()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.source_expressions)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.source_expressions)
|
||||
|
||||
|
||||
class TupleLookupMixin:
|
||||
allows_composite_expressions = True
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if self.rhs_is_direct_value():
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_length_equals_lhs_length()
|
||||
else:
|
||||
self.check_rhs_is_supported_expression()
|
||||
super().get_prep_lookup()
|
||||
return self.rhs
|
||||
|
||||
def check_rhs_is_tuple_or_list(self):
|
||||
if not isinstance(self.rhs, (tuple, list)):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
|
||||
)
|
||||
|
||||
def check_rhs_length_equals_lhs_length(self):
|
||||
len_lhs = len(self.lhs)
|
||||
if len_lhs != len(self.rhs):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
|
||||
)
|
||||
|
||||
def check_rhs_is_supported_expression(self):
|
||||
if not isinstance(self.rhs, (ResolvedOuterRef, Query)):
|
||||
lhs_str = self.get_lhs_str()
|
||||
rhs_cls = self.rhs.__class__.__name__
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
||||
f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
|
||||
)
|
||||
|
||||
def get_lhs_str(self):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
return repr(self.lhs.field.name)
|
||||
else:
|
||||
names = ", ".join(repr(f.name) for f in self.lhs)
|
||||
return f"({names})"
|
||||
|
||||
def get_prep_lhs(self):
|
||||
if isinstance(self.lhs, (tuple, list)):
|
||||
return Tuple(*self.lhs)
|
||||
return super().get_prep_lhs()
|
||||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
sql, params = super().process_lhs(compiler, connection, lhs)
|
||||
if not isinstance(self.lhs, Tuple):
|
||||
sql = f"({sql})"
|
||||
return sql, params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
args = [
|
||||
Value(val, output_field=col.output_field)
|
||||
for col, val in zip(self.lhs, self.rhs)
|
||||
]
|
||||
return compiler.compile(Tuple(*args))
|
||||
else:
|
||||
sql, params = compiler.compile(self.rhs)
|
||||
if isinstance(self.rhs, ColPairs):
|
||||
return "(%s)" % sql, params
|
||||
elif isinstance(self.rhs, Query):
|
||||
return super().process_rhs(compiler, connection)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Composite field lookups only work with composite expressions."
|
||||
)
|
||||
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__}.get_fallback_sql() must be implemented "
|
||||
f"for backends that don't have the supports_tuple_lookups feature enabled."
|
||||
)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_tuple_lookups:
|
||||
return self.get_fallback_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleExact(TupleLookupMixin, Exact):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
if isinstance(self.rhs, Query):
|
||||
return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) == (x, y, z) as SQL:
|
||||
# WHERE a = x AND b = y AND c = z
|
||||
lookups = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)]
|
||||
root = WhereNode(lookups, connector=AND)
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleIsNull(TupleLookupMixin, IsNull):
|
||||
def get_prep_lookup(self):
|
||||
rhs = self.rhs
|
||||
if isinstance(rhs, (tuple, list)) and len(rhs) == 1:
|
||||
rhs = rhs[0]
|
||||
if isinstance(rhs, bool):
|
||||
return rhs
|
||||
raise ValueError(
|
||||
"The QuerySet value for an isnull lookup must be True or False."
|
||||
)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# e.g.: (a, b, c) is None as SQL:
|
||||
# WHERE a IS NULL OR b IS NULL OR c IS NULL
|
||||
# e.g.: (a, b, c) is not None as SQL:
|
||||
# WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
|
||||
rhs = self.rhs
|
||||
lookups = [IsNull(col, rhs) for col in self.lhs]
|
||||
root = WhereNode(lookups, connector=OR if rhs else AND)
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) > (x, y, z) as SQL:
|
||||
# WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
|
||||
lookups = itertools.cycle([GreaterThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list[:-1])
|
||||
vals_iter = iter(vals_list[:-1])
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) >= (x, y, z) as SQL:
|
||||
# WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
|
||||
lookups = itertools.cycle([GreaterThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list)
|
||||
vals_iter = iter(vals_list)
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleLessThan(TupleLookupMixin, LessThan):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) < (x, y, z) as SQL:
|
||||
# WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
|
||||
lookups = itertools.cycle([LessThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list[:-1])
|
||||
vals_iter = iter(vals_list[:-1])
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) <= (x, y, z) as SQL:
|
||||
# WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
|
||||
lookups = itertools.cycle([LessThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list)
|
||||
vals_iter = iter(vals_list)
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleIn(TupleLookupMixin, In):
|
||||
def get_prep_lookup(self):
|
||||
if self.rhs_is_direct_value():
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_is_collection_of_tuples_or_lists()
|
||||
self.check_rhs_elements_length_equals_lhs_length()
|
||||
else:
|
||||
self.check_rhs_is_query()
|
||||
super(TupleLookupMixin, self).get_prep_lookup()
|
||||
|
||||
return self.rhs # skip checks from mixin
|
||||
|
||||
def check_rhs_is_collection_of_tuples_or_lists(self):
|
||||
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} "
|
||||
"must be a collection of tuples or lists"
|
||||
)
|
||||
|
||||
def check_rhs_elements_length_equals_lhs_length(self):
|
||||
len_lhs = len(self.lhs)
|
||||
if not all(len_lhs == len(vals) for vals in self.rhs):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} "
|
||||
f"must have {len_lhs} elements each"
|
||||
)
|
||||
|
||||
def check_rhs_is_query(self):
|
||||
if not isinstance(self.rhs, (Query, Subquery)):
|
||||
lhs_str = self.get_lhs_str()
|
||||
rhs_cls = self.rhs.__class__.__name__
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
||||
f"must be a Query object (received {rhs_cls!r})"
|
||||
)
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if not self.rhs_is_direct_value():
|
||||
return super(TupleLookupMixin, self).process_rhs(compiler, connection)
|
||||
|
||||
rhs = self.rhs
|
||||
if not rhs:
|
||||
raise EmptyResultSet
|
||||
|
||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||
# WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
|
||||
result = []
|
||||
lhs = self.lhs
|
||||
|
||||
for vals in rhs:
|
||||
result.append(
|
||||
Tuple(
|
||||
*[
|
||||
Value(val, output_field=col.output_field)
|
||||
for col, val in zip(lhs, vals)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return compiler.compile(Tuple(*result))
|
||||
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
rhs = self.rhs
|
||||
if not rhs:
|
||||
raise EmptyResultSet
|
||||
if not self.rhs_is_direct_value():
|
||||
return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
||||
|
||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
||||
root = WhereNode([], connector=OR)
|
||||
lhs = self.lhs
|
||||
|
||||
for vals in rhs:
|
||||
lookups = [Exact(col, val) for col, val in zip(lhs, vals)]
|
||||
root.children.append(WhereNode(lookups, connector=AND))
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
tuple_lookups = {
|
||||
"exact": TupleExact,
|
||||
"gt": TupleGreaterThan,
|
||||
"gte": TupleGreaterThanOrEqual,
|
||||
"lt": TupleLessThan,
|
||||
"lte": TupleLessThanOrEqual,
|
||||
"in": TupleIn,
|
||||
"isnull": TupleIsNull,
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue