start-pack

commit 3e1fa59b3d
Author: bdrtr
Date:   2025-04-28 15:42:23 +03:00

5723 changed files with 757971 additions and 0 deletions


@@ -0,0 +1,611 @@
"""
PostgreSQL database backend for Django.
Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
"""
import asyncio
import threading
import warnings
from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
from django.utils.version import get_version_tuple
try:
try:
import psycopg as Database
except ImportError:
import psycopg2 as Database
except ImportError:
raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
def psycopg_version():
version = Database.__version__.split(" ", 1)[0]
return get_version_tuple(version)
if psycopg_version() < (2, 8, 4):
raise ImproperlyConfigured(
f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
)
if (3,) <= psycopg_version() < (3, 1, 8):
raise ImproperlyConfigured(
f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}"
)
from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
if is_psycopg3:
from psycopg import adapters, sql
from psycopg.pq import Format
from .psycopg_any import get_adapters_template, register_tzloader
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
else:
import psycopg2.extensions
import psycopg2.extras
psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
psycopg2.extras.register_uuid()
# Register support for inet[] manually so we don't have to handle the Inet()
# object on load all the time.
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
"INETARRAY",
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
# Some of these import psycopg, so import them after checking if it's installed.
from .client import DatabaseClient # NOQA isort:skip
from .creation import DatabaseCreation # NOQA isort:skip
from .features import DatabaseFeatures # NOQA isort:skip
from .introspection import DatabaseIntrospection # NOQA isort:skip
from .operations import DatabaseOperations # NOQA isort:skip
from .schema import DatabaseSchemaEditor # NOQA isort:skip
def _get_varchar_column(data):
if data["max_length"] is None:
return "varchar"
return "varchar(%(max_length)s)" % data
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
"AutoField": "integer",
"BigAutoField": "bigint",
"BinaryField": "bytea",
"BooleanField": "boolean",
"CharField": _get_varchar_column,
"DateField": "date",
"DateTimeField": "timestamp with time zone",
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
"DurationField": "interval",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "double precision",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "inet",
"GenericIPAddressField": "inet",
"JSONField": "jsonb",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint",
"PositiveIntegerField": "integer",
"PositiveSmallIntegerField": "smallint",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallint",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
"UUIDField": "uuid",
}
data_type_check_constraints = {
"PositiveBigIntegerField": '"%(column)s" >= 0',
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
}
data_types_suffix = {
"AutoField": "GENERATED BY DEFAULT AS IDENTITY",
"BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
"SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
}
operators = {
"exact": "= %s",
"iexact": "= UPPER(%s)",
"contains": "LIKE %s",
"icontains": "LIKE UPPER(%s)",
"regex": "~ %s",
"iregex": "~* %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE %s",
"endswith": "LIKE %s",
"istartswith": "LIKE UPPER(%s)",
"iendswith": "LIKE UPPER(%s)",
}
# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, %, _) should be
# escaped on the database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = (
r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
)
pattern_ops = {
"contains": "LIKE '%%' || {} || '%%'",
"icontains": "LIKE '%%' || UPPER({}) || '%%'",
"startswith": "LIKE {} || '%%'",
"istartswith": "LIKE UPPER({}) || '%%'",
"endswith": "LIKE '%%' || {}",
"iendswith": "LIKE '%%' || UPPER({})",
}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
# Classes instantiated in __init__().
client_class = DatabaseClient
creation_class = DatabaseCreation
features_class = DatabaseFeatures
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
# PostgreSQL backend-specific attributes.
_named_cursor_idx = 0
_connection_pools = {}
@property
def pool(self):
pool_options = self.settings_dict["OPTIONS"].get("pool")
if self.alias == NO_DB_ALIAS or not pool_options:
return None
if self.alias not in self._connection_pools:
if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
raise ImproperlyConfigured(
"Pooling doesn't support persistent connections."
)
# Set the default options.
if pool_options is True:
pool_options = {}
try:
from psycopg_pool import ConnectionPool
except ImportError as err:
raise ImproperlyConfigured(
"Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
) from err
connect_kwargs = self.get_connection_params()
# Ensure we run in autocommit; Django properly sets it later on.
connect_kwargs["autocommit"] = True
enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
pool = ConnectionPool(
kwargs=connect_kwargs,
open=False, # Do not open the pool during startup.
configure=self._configure_connection,
check=ConnectionPool.check_connection if enable_checks else None,
**pool_options,
)
# setdefault() ensures that multiple threads don't set this in
# parallel. Since we do not open the pool during its init above,
# this means that at worst during startup multiple threads generate
# pool objects and the first to set it wins.
self._connection_pools.setdefault(self.alias, pool)
return self._connection_pools[self.alias]
def close_pool(self):
if self.pool:
self.pool.close()
del self._connection_pools[self.alias]
def get_database_version(self):
"""
Return a tuple of the database's version.
E.g. for pg_version 120004, return (12, 4).
"""
return divmod(self.pg_version, 10000)
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME or OPTIONS['service'] value."
)
if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
"in settings.DATABASES."
% (
settings_dict["NAME"],
len(settings_dict["NAME"]),
self.ops.max_name_length(),
)
)
if settings_dict["NAME"]:
conn_params = {
"dbname": settings_dict["NAME"],
**settings_dict["OPTIONS"],
}
elif settings_dict["NAME"] is None:
# Connect to the default 'postgres' db.
settings_dict["OPTIONS"].pop("service", None)
conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]}
else:
conn_params = {**settings_dict["OPTIONS"]}
conn_params["client_encoding"] = "UTF8"
conn_params.pop("assume_role", None)
conn_params.pop("isolation_level", None)
pool_options = conn_params.pop("pool", None)
if pool_options and not is_psycopg3:
raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
server_side_binding = conn_params.pop("server_side_binding", None)
conn_params.setdefault(
"cursor_factory",
(
ServerBindingCursor
if is_psycopg3 and server_side_binding is True
else Cursor
),
)
if settings_dict["USER"]:
conn_params["user"] = settings_dict["USER"]
if settings_dict["PASSWORD"]:
conn_params["password"] = settings_dict["PASSWORD"]
if settings_dict["HOST"]:
conn_params["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
conn_params["port"] = settings_dict["PORT"]
if is_psycopg3:
conn_params["context"] = get_adapters_template(
settings.USE_TZ, self.timezone
)
# Disable prepared statements by default to keep connection poolers
# working. Can be reenabled via OPTIONS in the settings dict.
conn_params["prepare_threshold"] = conn_params.pop(
"prepare_threshold", None
)
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
# self.isolation_level must be set:
# - after connecting to the database in order to obtain the database's
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
options = self.settings_dict["OPTIONS"]
set_isolation_level = False
try:
isolation_level_value = options["isolation_level"]
except KeyError:
self.isolation_level = IsolationLevel.READ_COMMITTED
else:
# Set the isolation level to the value from OPTIONS.
try:
self.isolation_level = IsolationLevel(isolation_level_value)
set_isolation_level = True
except ValueError:
raise ImproperlyConfigured(
f"Invalid transaction isolation level {isolation_level_value} "
f"specified. Use one of the psycopg.IsolationLevel values."
)
if self.pool:
# If nothing else has opened the pool, open it now.
self.pool.open()
connection = self.pool.getconn()
else:
connection = self.Database.connect(**conn_params)
if set_isolation_level:
connection.isolation_level = self.isolation_level
if not is_psycopg3:
# Register dummy loads() to avoid a round trip from psycopg2's
# decode to json.dumps() to json.loads(), when using a custom
# decoder in JSONField.
psycopg2.extras.register_default_jsonb(
conn_or_curs=connection, loads=lambda x: x
)
return connection
def ensure_timezone(self):
# Close the pool so new connections pick up the correct timezone.
self.close_pool()
if self.connection is None:
return False
return self._configure_timezone(self.connection)
def _configure_timezone(self, connection):
conn_timezone_name = connection.info.parameter_status("TimeZone")
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with connection.cursor() as cursor:
cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
return True
return False
def _configure_role(self, connection):
if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
with connection.cursor() as cursor:
sql = self.ops.compose_sql("SET ROLE %s", [new_role])
cursor.execute(sql)
return True
return False
def _configure_connection(self, connection):
# This function is called from init_connection_state and from the
# psycopg pool itself after a connection is opened.
# Commit after setting the time zone.
commit_tz = self._configure_timezone(connection)
# Set the role on the connection. This is useful if the credential used
# to log in is not the same as the role that owns database resources, as
# can be the case when using temporary or ephemeral credentials.
commit_role = self._configure_role(connection)
return commit_role or commit_tz
def _close(self):
if self.connection is not None:
# `wrap_database_errors` only works for `putconn` as long as there
# is no `reset` function set in the pool because it is deferred
# into a thread and not directly executed.
with self.wrap_database_errors:
if self.pool:
# Ensure the correct pool is returned. This is a workaround
# for tests so a pool can be changed on setting changes
# (e.g. USE_TZ, TIME_ZONE).
self.connection._pool.putconn(self.connection)
# Connection can no longer be used.
self.connection = None
else:
return self.connection.close()
def init_connection_state(self):
super().init_connection_state()
if self.connection is not None and not self.pool:
commit = self._configure_connection(self.connection)
if commit and not self.get_autocommit():
self.connection.commit()
@async_unsafe
def create_cursor(self, name=None):
if name:
if is_psycopg3 and (
self.settings_dict["OPTIONS"].get("server_side_binding") is not True
):
# psycopg >= 3 forces the usage of server-side bindings for
# named cursors so a specialized class that implements
# server-side cursors while performing client-side bindings
# must be used if `server_side_binding` is disabled (default).
cursor = ServerSideCursor(
self.connection,
name=name,
scrollable=False,
withhold=self.connection.autocommit,
)
else:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.connection.cursor(
name, scrollable=False, withhold=self.connection.autocommit
)
else:
cursor = self.connection.cursor()
if is_psycopg3:
# Register the cursor timezone only if the connection disagrees, to
# avoid copying the adapter map.
tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
if self.timezone != tzloader.timezone:
register_tzloader(self.timezone, cursor)
else:
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
return cursor
def tzinfo_factory(self, offset):
return self.timezone
@async_unsafe
def chunked_cursor(self):
self._named_cursor_idx += 1
# Get the current async task
# Note that right now this is behind @async_unsafe, so this is
# unreachable, but in future we'll start loosening this restriction.
# For now, it's here so that every use of "threading" is
# also async-compatible.
try:
current_task = asyncio.current_task()
except RuntimeError:
current_task = None
# Current task can be None even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = "sync"
# Use that and the thread ident to get a unique name
return self._cursor(
name="_django_curs_%d_%s_%d"
% (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
self._named_cursor_idx,
)
)
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
self.connection.autocommit = autocommit
def check_constraints(self, table_names=None):
"""
Check constraints by setting them to immediate. Return them to deferred
afterward.
"""
with self.cursor() as cursor:
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
def is_usable(self):
if self.connection is None:
return False
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
except Database.Error:
return False
else:
return True
def close_if_health_check_failed(self):
if self.pool:
# The pool only returns healthy connections.
return
return super().close_if_health_check_failed()
@contextmanager
def _nodb_cursor(self):
cursor = None
try:
with super()._nodb_cursor() as cursor:
yield cursor
except (Database.DatabaseError, WrappedDatabaseError):
if cursor is not None:
raise
warnings.warn(
"Normally Django will use a connection to the 'postgres' database "
"to avoid running initialization queries against the production "
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
RuntimeWarning,
)
for connection in connections.all():
if (
connection.vendor == "postgresql"
and connection.settings_dict["NAME"] != "postgres"
):
conn = self.__class__(
{
**self.settings_dict,
"NAME": connection.settings_dict["NAME"],
},
alias=self.alias,
)
try:
with conn.cursor() as cursor:
yield cursor
finally:
conn.close()
break
else:
raise
@cached_property
def pg_version(self):
with self.temporary_connection():
return self.connection.info.server_version
def make_debug_cursor(self, cursor):
return CursorDebugWrapper(cursor, self)
if is_psycopg3:
class CursorMixin:
"""
A subclass of psycopg cursor implementing callproc.
"""
def callproc(self, name, args=None):
if not isinstance(name, sql.Identifier):
name = sql.Identifier(name)
qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
if args:
for item in args:
qparts.append(sql.Literal(item))
qparts.append(sql.SQL(","))
del qparts[-1]
qparts.append(sql.SQL(")"))
stmt = sql.Composed(qparts)
self.execute(stmt)
return args
class ServerBindingCursor(CursorMixin, Database.Cursor):
pass
class Cursor(CursorMixin, Database.ClientCursor):
pass
class ServerSideCursor(
CursorMixin, Database.client_cursor.ClientCursorMixin, Database.ServerCursor
):
"""
psycopg >= 3 forces the usage of server-side bindings when using named
cursors but the ORM doesn't yet support the systematic generation of
prepareable SQL (#20516).
ClientCursorMixin forces the usage of client-side bindings while
ServerCursor implements the logic required to declare and scroll
through named cursors.
Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed
specifying how parameters should be bound instead, which ServerCursor
would inherit, but that's not the case.
"""
class CursorDebugWrapper(BaseCursorDebugWrapper):
def copy(self, statement):
with self.debug_sql(statement):
return self.cursor.copy(statement)
else:
Cursor = psycopg2.extensions.cursor
class CursorDebugWrapper(BaseCursorDebugWrapper):
def copy_expert(self, sql, file, *args):
with self.debug_sql(sql):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
with self.debug_sql(sql="COPY %s TO STDOUT" % table):
return self.cursor.copy_to(file, table, *args, **kwargs)
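
For orientation, a minimal, hypothetical settings sketch showing how the OPTIONS keys consumed above by get_connection_params(), the pool property, and _configure_role() would be supplied; the database name, credentials, role, and pool sizes are placeholders, and "pool" requires psycopg 3 with the psycopg[pool] extra and CONN_MAX_AGE left at 0.

from django.db.backends.postgresql.psycopg_any import IsolationLevel

DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": "mydb",
        "USER": "myuser",
        "PASSWORD": "secret",
        "HOST": "localhost",
        "PORT": "5432",
        "OPTIONS": {
            # Mapped to an IsolationLevel value in get_new_connection().
            "isolation_level": IsolationLevel.SERIALIZABLE,
            # Executed as SET ROLE by _configure_role(); placeholder role name.
            "assume_role": "app_role",
            # psycopg 3 only; selects ServerBindingCursor instead of Cursor.
            "server_side_binding": True,
            # psycopg 3 + psycopg[pool] only; kwargs passed to ConnectionPool.
            "pool": {"min_size": 2, "max_size": 4},
        },
    }
}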


@@ -0,0 +1,64 @@
import signal
from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "psql"
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
options = settings_dict["OPTIONS"]
host = settings_dict.get("HOST")
port = settings_dict.get("PORT")
dbname = settings_dict.get("NAME")
user = settings_dict.get("USER")
passwd = settings_dict.get("PASSWORD")
passfile = options.get("passfile")
service = options.get("service")
sslmode = options.get("sslmode")
sslrootcert = options.get("sslrootcert")
sslcert = options.get("sslcert")
sslkey = options.get("sslkey")
if not dbname and not service:
# Connect to the default 'postgres' db.
dbname = "postgres"
if user:
args += ["-U", user]
if host:
args += ["-h", host]
if port:
args += ["-p", str(port)]
args.extend(parameters)
if dbname:
args += [dbname]
env = {}
if passwd:
env["PGPASSWORD"] = str(passwd)
if service:
env["PGSERVICE"] = str(service)
if sslmode:
env["PGSSLMODE"] = str(sslmode)
if sslrootcert:
env["PGSSLROOTCERT"] = str(sslrootcert)
if sslcert:
env["PGSSLCERT"] = str(sslcert)
if sslkey:
env["PGSSLKEY"] = str(sslkey)
if passfile:
env["PGPASSFILE"] = str(passfile)
return args, (env or None)
def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try:
# Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
super().runshell(parameters)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)
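
As an illustration of the mapping performed by settings_to_cmd_args_env() above (all values are placeholders):

args, env = DatabaseClient.settings_to_cmd_args_env(
    {
        "NAME": "mydb",
        "USER": "myuser",
        "PASSWORD": "secret",
        "HOST": "localhost",
        "PORT": "5432",
        "OPTIONS": {"sslmode": "require"},
    },
    ["--quiet"],
)
# args == ["psql", "-U", "myuser", "-h", "localhost", "-p", "5432", "--quiet", "mydb"]
# env == {"PGPASSWORD": "secret", "PGSSLMODE": "require"}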


@@ -0,0 +1,50 @@
from django.db.models.sql.compiler import (
SQLAggregateCompiler,
SQLCompiler,
SQLDeleteCompiler,
)
from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler
from django.db.models.sql.compiler import SQLUpdateCompiler
__all__ = [
"SQLAggregateCompiler",
"SQLCompiler",
"SQLDeleteCompiler",
"SQLInsertCompiler",
"SQLUpdateCompiler",
]
class InsertUnnest(list):
"""
Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the
UNNEST strategy should be used for the bulk insert.
"""
def __str__(self):
return "UNNEST(%s)" % ", ".join(self)
class SQLInsertCompiler(BaseSQLInsertCompiler):
def assemble_as_sql(self, fields, value_rows):
# Specialize bulk-insertion of literal non-array values through
# UNNEST to reduce the time spent planning the query.
if (
# The optimization is not worth doing if there is a single
# row as it will result in the same number of placeholders.
len(value_rows) <= 1
# Lack of fields denotes the usage of the DEFAULT keyword
# for the insertion of empty rows.
or any(field is None for field in fields)
# Compilable cannot be combined in an array of literal values.
or any(any(hasattr(value, "as_sql") for value in row) for row in value_rows)
):
return super().assemble_as_sql(fields, value_rows)
db_types = [field.db_type(self.connection) for field in fields]
# Abort if any of the fields are arrays as UNNEST indiscriminately
# flattens them instead of reducing their nesting by one.
if any(db_type.endswith("]") for db_type in db_types):
return super().assemble_as_sql(fields, value_rows)
return InsertUnnest(["(%%s)::%s[]" % db_type for db_type in db_types]), [
list(map(list, zip(*value_rows)))
]
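
To make the UNNEST strategy concrete, a worked sketch (column types and values are hypothetical) of what assemble_as_sql() returns for two rows of an integer and a text column, and the statement DatabaseOperations.bulk_insert_sql() then builds from it:

# fields:       an IntegerField column and a TextField column
# value_rows:   [[1, "a"], [2, "b"]]
# placeholders: InsertUnnest(["(%s)::integer[]", "(%s)::text[]"])
# params:       [[[1, 2], ["a", "b"]]]   # rows transposed into one array per column
# final SQL:    INSERT INTO ... SELECT * FROM UNNEST((%s)::integer[], (%s)::text[])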


@@ -0,0 +1,91 @@
import sys
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.postgresql.psycopg_any import errors
from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
def _get_database_create_suffix(self, encoding=None, template=None):
suffix = ""
if encoding:
suffix += " ENCODING '{}'".format(encoding)
if template:
suffix += " TEMPLATE {}".format(self._quote_name(template))
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
test_settings = self.connection.settings_dict["TEST"]
if test_settings.get("COLLATION") is not None:
raise ImproperlyConfigured(
"PostgreSQL does not support collation setting at database "
"creation time."
)
return self._get_database_create_suffix(
encoding=test_settings["CHARSET"],
template=test_settings.get("TEMPLATE"),
)
def _database_exists(self, cursor, database_name):
cursor.execute(
"SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
[strip_quotes(database_name)],
)
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
if keepdb and self._database_exists(cursor, parameters["dbname"]):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if not isinstance(e.__cause__, errors.DuplicateDatabase):
# All errors except "database already exists" cancel tests.
self.log("Got an error creating the test database: %s" % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
# exists".
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
# CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
# to the template database.
self.connection.close()
self.connection.close_pool()
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_db_params = {
"dbname": self._quote_name(target_database_name),
"suffix": self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_cursor() as cursor:
try:
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception:
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
sys.exit(2)
def _destroy_test_db(self, test_database_name, verbosity):
self.connection.close_pool()
return super()._destroy_test_db(test_database_name, verbosity)
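
A small worked example (names hypothetical) of the suffix built by _get_database_create_suffix() when the TEST settings provide a charset and a template:

# _get_database_create_suffix(encoding="UTF8", template="template0")
# returns:  WITH ENCODING 'UTF8' TEMPLATE "template0"
# so the test database creation statement is roughly:
#   CREATE DATABASE "test_mydb" WITH ENCODING 'UTF8' TEMPLATE "template0"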


@@ -0,0 +1,170 @@
import operator
from django.db import DataError, InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.postgresql.psycopg_any import is_psycopg3
from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (14,)
allows_group_by_selected_pks = True
can_return_columns_from_insert = True
can_return_rows_from_bulk_insert = True
has_real_datatype = True
has_native_uuid_field = True
has_native_duration_field = True
has_native_json_field = True
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_of = True
has_select_for_update_skip_locked = True
has_select_for_no_key_update = True
can_release_savepoints = True
supports_comments = True
supports_tablespaces = True
supports_transactions = True
can_introspect_materialized_views = True
can_distinct_on_fields = True
can_rollback_ddl = True
schema_editor_uses_clientside_param_binding = True
supports_combined_alters = True
nulls_order_largest = True
closed_cursor_error_class = InterfaceError
greatest_least_ignores_nulls = True
can_clone_databases = True
supports_temporal_subtraction = True
supports_slicing_ordering_in_compound = True
create_test_procedure_without_params_sql = """
CREATE FUNCTION test_procedure () RETURNS void AS $$
DECLARE
V_I INTEGER;
BEGIN
V_I := 1;
END;
$$ LANGUAGE plpgsql;"""
create_test_procedure_with_int_param_sql = """
CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$
DECLARE
V_I INTEGER;
BEGIN
V_I := P_I;
END;
$$ LANGUAGE plpgsql;"""
create_test_table_with_composite_primary_key = """
CREATE TABLE test_table_composite_pk (
column_1 INTEGER NOT NULL,
column_2 INTEGER NOT NULL,
PRIMARY KEY(column_1, column_2)
)
"""
requires_casted_case_in_updates = True
supports_over_clause = True
supports_frame_exclusion = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
supports_deferrable_unique_constraints = True
has_json_operators = True
json_key_contains_list_matching_requires_list = True
supports_update_conflicts = True
supports_update_conflicts_with_target = True
supports_covering_indexes = True
supports_stored_generated_columns = True
supports_virtual_generated_columns = False
can_rename_index = True
test_collations = {
"deterministic": "C",
"non_default": "sv-x-icu",
"swedish_ci": "sv-x-icu",
"virtual": "sv-x-icu",
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES"
@cached_property
def django_test_skips(self):
skips = {
"opclasses are PostgreSQL only.": {
"indexes.tests.SchemaIndexesNotPostgreSQLTests."
"test_create_index_ignores_opclasses",
},
"PostgreSQL requires casting to text.": {
"lookup.tests.LookupTests.test_textfield_exact_null",
},
}
if self.connection.settings_dict["OPTIONS"].get("pool"):
skips.update(
{
"Pool does implicit health checks": {
"backends.base.test_base.ConnectionHealthChecksTests."
"test_health_checks_enabled",
"backends.base.test_base.ConnectionHealthChecksTests."
"test_set_autocommit_health_checks_enabled",
},
}
)
if self.uses_server_side_binding:
skips.update(
{
"The actual query cannot be determined for server side bindings": {
"backends.base.test_base.ExecuteWrapperTests."
"test_wrapper_debug",
}
},
)
return skips
@cached_property
def django_test_expected_failures(self):
expected_failures = set()
if self.uses_server_side_binding:
expected_failures.update(
{
# Parameters passed to expressions in SELECT and GROUP BY
# clauses are not recognized as the same values when using
# server-side binding cursors (#34255).
"aggregation.tests.AggregateTestCase."
"test_group_by_nested_expression_with_params",
}
)
return expected_failures
@cached_property
def uses_server_side_binding(self):
options = self.connection.settings_dict["OPTIONS"]
return is_psycopg3 and options.get("server_side_binding") is True
@cached_property
def prohibits_null_characters_in_text_exception(self):
if is_psycopg3:
return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes"
else:
return ValueError, "A string literal cannot contain NUL (0x00) characters."
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"PositiveBigIntegerField": "BigIntegerField",
"PositiveIntegerField": "IntegerField",
"PositiveSmallIntegerField": "SmallIntegerField",
}
@cached_property
def is_postgresql_15(self):
return self.connection.pg_version >= 150000
@cached_property
def is_postgresql_16(self):
return self.connection.pg_version >= 160000
@cached_property
def is_postgresql_17(self):
return self.connection.pg_version >= 170000
supports_unlimited_charfield = True
supports_nulls_distinct_unique_constraints = property(
operator.attrgetter("is_postgresql_15")
)
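
A brief, hypothetical usage sketch: these flags are consulted at runtime through connection.features, so driver- or version-dependent code can be gated on them.

from django.db import connection

if connection.features.uses_server_side_binding:
    # psycopg 3 with OPTIONS["server_side_binding"] = True.
    ...
if connection.features.supports_nulls_distinct_unique_constraints:
    # Only True on PostgreSQL 15+, per is_postgresql_15 above.
    ...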


@@ -0,0 +1,299 @@
from collections import namedtuple
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
from django.db.models import Index
FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "comment"))
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: "BooleanField",
17: "BinaryField",
20: "BigIntegerField",
21: "SmallIntegerField",
23: "IntegerField",
25: "TextField",
700: "FloatField",
701: "FloatField",
869: "GenericIPAddressField",
1042: "CharField", # blank-padded
1043: "CharField",
1082: "DateField",
1083: "TimeField",
1114: "DateTimeField",
1184: "DateTimeField",
1186: "DurationField",
1266: "TimeField",
1700: "DecimalField",
2950: "UUIDField",
3802: "JSONField",
}
# A hook for subclasses.
index_default_access_method = "btree"
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.is_autofield or (
# Required for pre-Django 4.1 serial columns.
description.default
and "nextval" in description.default
):
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute(
"""
SELECT
c.relname,
CASE
WHEN c.relispartition THEN 'p'
WHEN c.relkind IN ('m', 'v') THEN 'v'
ELSE 't'
END,
obj_description(c.oid, 'pg_class')
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
"""
)
return [
TableInfo(*row)
for row in cursor.fetchall()
if row[0] not in self.ignored_tables
]
def get_table_description(self, cursor, table_name):
"""
Return a description of the table with the DB-API cursor.description
interface.
"""
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
cursor.execute(
"""
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
a.attidentity != '' AS is_autofield,
col_description(a.attrelid, a.attnum) AS column_comment
FROM pg_attribute a
LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
LEFT JOIN pg_collation co ON a.attcollation = co.oid
JOIN pg_type t ON a.atttypid = t.oid
JOIN pg_class c ON a.attrelid = c.oid
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""",
[table_name],
)
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
cursor.execute(
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
)
return [
FieldInfo(
line.name,
line.type_code,
# display_size is always None on psycopg2.
line.internal_size if line.display_size is None else line.display_size,
line.internal_size,
line.precision,
line.scale,
*field_map[line.name],
)
for line in cursor.description
]
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute(
"""
SELECT
s.relname AS sequence_name,
a.attname AS colname
FROM
pg_class s
JOIN pg_depend d ON d.objid = s.oid
AND d.classid = 'pg_class'::regclass
AND d.refclassid = 'pg_class'::regclass
JOIN pg_attribute a ON d.refobjid = a.attrelid
AND d.refobjsubid = a.attnum
JOIN pg_class tbl ON tbl.oid = d.refobjid
AND tbl.relname = %s
AND pg_catalog.pg_table_is_visible(tbl.oid)
WHERE
s.relkind = 'S';
""",
[table_name],
)
return [
{"name": row[0], "table": table_name, "column": row[1]}
for row in cursor.fetchall()
]
def get_relations(self, cursor, table_name):
"""
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all foreign keys in the given table.
"""
cursor.execute(
"""
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
LEFT JOIN
pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN
pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
WHERE
c1.relname = %s AND
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
""",
[table_name],
)
return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
def get_constraints(self, cursor, table_name):
"""
Retrieve any constraints or keys (unique, pk, fk, check, index) across
one or more columns. Also retrieve the definition of expression-based
indexes.
"""
constraints = {}
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
cursor.execute(
"""
SELECT
c.conname,
array(
SELECT attname
FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
JOIN pg_attribute AS ca ON cols.colid = ca.attnum
WHERE ca.attrelid = c.conrelid
ORDER BY cols.arridx
),
c.contype,
(SELECT fkc.relname || '.' || fka.attname
FROM pg_attribute AS fka
JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
cl.reloptions
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
""",
[table_name],
)
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
"primary_key": kind == "p",
"unique": kind in ["p", "u"],
"foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
"check": kind == "c",
"index": False,
"definition": None,
"options": options,
}
# Now get indexes
cursor.execute(
"""
SELECT
indexname,
array_agg(attname ORDER BY arridx),
indisunique,
indisprimary,
array_agg(ordering ORDER BY arridx),
amname,
exprdef,
s2.attoptions
FROM (
SELECT
c2.relname as indexname, idx.*, attr.attname, am.amname,
CASE
WHEN idx.indexprs IS NOT NULL THEN
pg_get_indexdef(idx.indexrelid)
END AS exprdef,
CASE am.amname
WHEN %s THEN
CASE (option & 1)
WHEN 1 THEN 'DESC' ELSE 'ASC'
END
END as ordering,
c2.reloptions as attoptions
FROM (
SELECT *
FROM
pg_index i,
unnest(i.indkey, i.indoption)
WITH ORDINALITY koi(key, option, arridx)
) idx
LEFT JOIN pg_class c ON idx.indrelid = c.oid
LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
LEFT JOIN pg_am am ON c2.relam = am.oid
LEFT JOIN
pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
""",
[self.index_default_access_method, table_name],
)
for (
index,
columns,
unique,
primary,
orders,
type_,
definition,
options,
) in cursor.fetchall():
if index not in constraints:
basic_index = (
type_ == self.index_default_access_method
and
# '_btree' references
# django.contrib.postgres.indexes.BTreeIndex.suffix.
not index.endswith("_btree")
and options is None
)
constraints[index] = {
"columns": columns if columns != [None] else [],
"orders": orders if orders != [None] else [],
"primary_key": primary,
"unique": unique,
"foreign_key": None,
"check": False,
"index": True,
"type": Index.suffix if basic_index else type_,
"definition": definition,
"options": options,
}
return constraints
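
A hedged usage sketch of the introspection entry points defined above; "my_table" is a placeholder table name.

from django.db import connection

with connection.cursor() as cursor:
    tables = connection.introspection.get_table_list(cursor)
    # -> [TableInfo(name, type, comment), ...]
    columns = connection.introspection.get_table_description(cursor, "my_table")
    # -> [FieldInfo(..., is_autofield, comment), ...]
    relations = connection.introspection.get_relations(cursor, "my_table")
    constraints = connection.introspection.get_constraints(cursor, "my_table")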


@@ -0,0 +1,422 @@
import json
from functools import lru_cache, partial
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.postgresql.compiler import InsertUnnest
from django.db.backends.postgresql.psycopg_any import (
Inet,
Jsonb,
errors,
is_psycopg3,
mogrify,
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile
@lru_cache
def get_json_dumps(encoder):
if encoder is None:
return json.dumps
return partial(json.dumps, cls=encoder)
class DatabaseOperations(BaseDatabaseOperations):
compiler_module = "django.db.backends.postgresql.compiler"
cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN"
explain_options = frozenset(
[
"ANALYZE",
"BUFFERS",
"COSTS",
"GENERIC_PLAN",
"MEMORY",
"SETTINGS",
"SERIALIZE",
"SUMMARY",
"TIMING",
"VERBOSE",
"WAL",
]
)
cast_data_types = {
"AutoField": "integer",
"BigAutoField": "bigint",
"SmallAutoField": "smallint",
}
if is_psycopg3:
from psycopg.types import numeric
integerfield_type_map = {
"SmallIntegerField": numeric.Int2,
"IntegerField": numeric.Int4,
"BigIntegerField": numeric.Int8,
"PositiveSmallIntegerField": numeric.Int2,
"PositiveIntegerField": numeric.Int4,
"PositiveBigIntegerField": numeric.Int8,
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in (
"GenericIPAddressField",
"IPAddressField",
"TimeField",
"UUIDField",
):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
# These fields cannot be implicitly cast back in the default
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
return (
"CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
)
return "%s"
# EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7.
return f"EXTRACT(DOW FROM {sql}) + 1", params
elif lookup_type == "iso_week_day":
return f"EXTRACT(ISODOW FROM {sql})", params
elif lookup_type == "iso_year":
return f"EXTRACT(ISOYEAR FROM {sql})", params
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid lookup type: {lookup_type!r}")
return f"EXTRACT({lookup_type} FROM {sql})", params
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
sign = "-" if sign == "+" else "+"
return f"{tzname}{sign}{offset}"
return tzname
def _convert_sql_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ:
tzname_param = self._prepare_tzname_delta(tzname)
return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
return sql, params
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"({sql})::date", params
def datetime_cast_time_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"({sql})::time", params
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
if lookup_type == "second":
# Truncate fractional seconds.
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def time_extract_sql(self, lookup_type, sql, params):
if lookup_type == "second":
# Truncate fractional seconds.
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
return self.date_extract_sql(lookup_type, sql, params)
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def bulk_insert_sql(self, fields, placeholder_rows):
if isinstance(placeholder_rows, InsertUnnest):
return f"SELECT * FROM {placeholder_rows}"
return super().bulk_insert_sql(fields, placeholder_rows)
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in (
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"regex",
"iregex",
):
if internal_type in ("IPAddressField", "GenericIPAddressField"):
lookup = "HOST(%s)"
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
lookup = "UPPER(%s)" % lookup
return lookup
def no_limit_value(self):
return None
def prepare_sql_script(self, sql):
return [sql]
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
return name # Quoting once is enough.
return '"%s"' % name
def compose_sql(self, sql, params):
return mogrify(sql, params, self.connection)
def set_time_zone_sql(self):
return "SELECT set_config('TimeZone', %s, false)"
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
sql_parts = [
style.SQL_KEYWORD("TRUNCATE"),
", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
]
if reset_sequences:
sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
if allow_cascade:
sql_parts.append(style.SQL_KEYWORD("CASCADE"))
return ["%s;" % " ".join(sql_parts)]
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
table_name = sequence_info["table"]
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
column_name = sequence_info["column"] or "id"
sql.append(
"%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
% (
style.SQL_KEYWORD("SELECT"),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
)
)
return sql
def tablespace_sql(self, tablespace, inline=False):
if inline:
return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
else:
return "TABLESPACE %s" % self.quote_name(tablespace)
def sequence_reset_sql(self, style, model_list):
from django.db import models
output = []
qn = self.quote_name
for model in model_list:
# Use `coalesce` to set the sequence for each model to the max pk
# value if there are records, or 1 if there are none. Set the
# `is_called` property (the third argument to `setval`) to true if
# there are records (as the max pk value is already in use),
# otherwise set it to false. Use pg_get_serial_sequence to get the
# underlying sequence name from the table name and column name.
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;"
% (
style.SQL_KEYWORD("SELECT"),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
style.SQL_KEYWORD("IS NOT"),
style.SQL_KEYWORD("FROM"),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
# Only one AutoField is allowed per model, so don't bother
# continuing.
break
return output
def prep_for_iexact_query(self, x):
return x
def max_name_length(self):
"""
Return the maximum length of an identifier.
The maximum length of an identifier is 63 by default, but can be
changed by recompiling PostgreSQL after editing the NAMEDATALEN
macro in src/include/pg_config_manual.h.
This implementation returns 63, but can be overridden by a custom
database backend that inherits most of its behavior from this one.
"""
return 63
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
else:
return ["DISTINCT"], []
if is_psycopg3:
def last_executed_query(self, cursor, sql, params):
if self.connection.features.uses_server_side_binding:
try:
return self.compose_sql(sql, params)
except errors.DataError:
return None
else:
if cursor._query and cursor._query.query is not None:
return cursor._query.query.decode()
return None
else:
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
# The query attribute is a Psycopg extension to the DB API 2.0.
if cursor.query is not None:
return cursor.query.decode()
return None
def return_insert_columns(self, fields):
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
if is_psycopg3:
def adapt_integerfield_value(self, value, internal_type):
if value is None or hasattr(value, "resolve_expression"):
return value
return self.integerfield_type_map[internal_type](value)
def adapt_datefield_value(self, value):
return value
def adapt_datetimefield_value(self, value):
return value
def adapt_timefield_value(self, value):
return value
def adapt_ipaddressfield_value(self, value):
if value:
return Inet(value)
return None
def adapt_json_value(self, value, encoder):
return Jsonb(value, dumps=get_json_dumps(encoder))
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == "DateField":
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params
return super().subtract_temporals(internal_type, lhs, rhs)
def explain_query_prefix(self, format=None, **options):
extra = {}
if serialize := options.pop("serialize", None):
if serialize.upper() in {"TEXT", "BINARY"}:
extra["SERIALIZE"] = serialize.upper()
# Normalize options.
if options:
options = {
name.upper(): "true" if value else "false"
for name, value in options.items()
}
for valid_option in self.explain_options:
value = options.pop(valid_option, None)
if value is not None:
extra[valid_option] = value
prefix = super().explain_query_prefix(format, **options)
if format:
extra["FORMAT"] = format
if extra:
prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
return prefix
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.IGNORE:
return "ON CONFLICT DO NOTHING"
if on_conflict == OnConflict.UPDATE:
return "ON CONFLICT(%s) DO UPDATE SET %s" % (
", ".join(map(self.quote_name, unique_fields)),
", ".join(
[
f"{field} = EXCLUDED.{field}"
for field in map(self.quote_name, update_fields)
]
),
)
return super().on_conflict_suffix_sql(
fields,
on_conflict,
update_fields,
unique_fields,
)
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr, rhs_expr = super().prepare_join_on_clause(
lhs_table, lhs_field, rhs_table, rhs_field
)
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
rhs_expr = Cast(rhs_expr, lhs_field)
return lhs_expr, rhs_expr
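
A worked example (column name and timezone illustrative, assuming settings.USE_TZ is True) of how _convert_sql_to_tz() composes with datetime_trunc_sql() above:

# datetime_trunc_sql("month", '"pub_date"', (), "UTC")
# -> ('DATE_TRUNC(%s, "pub_date" AT TIME ZONE %s)', ('month', 'UTC'))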


@@ -0,0 +1,114 @@
import ipaddress
from functools import lru_cache
try:
from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
from psycopg.postgres import types
from psycopg.types.datetime import TimestamptzLoader
from psycopg.types.json import Jsonb
from psycopg.types.range import Range, RangeDumper
from psycopg.types.string import TextLoader
Inet = ipaddress.ip_address
DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
RANGE_TYPES = (Range,)
TSRANGE_OID = types["tsrange"].oid
TSTZRANGE_OID = types["tstzrange"].oid
def mogrify(sql, params, connection):
with connection.cursor() as cursor:
return ClientCursor(cursor.connection).mogrify(sql, params)
# Adapters.
class BaseTzLoader(TimestamptzLoader):
"""
Load a PostgreSQL timestamptz using a specific timezone.
The timezone can be None too, in which case it will be chopped.
"""
timezone = None
def load(self, data):
res = super().load(data)
return res.replace(tzinfo=self.timezone)
def register_tzloader(tz, context):
class SpecificTzLoader(BaseTzLoader):
timezone = tz
context.adapters.register_loader("timestamptz", SpecificTzLoader)
class DjangoRangeDumper(RangeDumper):
"""A Range dumper customized for Django."""
def upgrade(self, obj, format):
# Dump ranges containing naive datetimes as tstzrange, because
# Django doesn't use tz-aware ones.
dumper = super().upgrade(obj, format)
if dumper is not self and dumper.oid == TSRANGE_OID:
dumper.oid = TSTZRANGE_OID
return dumper
@lru_cache
def get_adapters_template(use_tz, timezone):
# Create an adapters map extending the base one.
ctx = adapt.AdaptersMap(adapters)
# Register a no-op loader to avoid a round trip from psycopg version 3
# decode to json.dumps() to json.loads(), when using a custom decoder
# in JSONField.
ctx.register_loader("jsonb", TextLoader)
# Don't convert automatically from PostgreSQL network types to Python
# ipaddress.
ctx.register_loader("inet", TextLoader)
ctx.register_loader("cidr", TextLoader)
ctx.register_dumper(Range, DjangoRangeDumper)
# Register a timestamptz loader configured on the given timezone.
# This, however, can be overridden by create_cursor.
register_tzloader(timezone, ctx)
return ctx
is_psycopg3 = True
except ImportError:
from enum import IntEnum
from psycopg2 import errors, extensions, sql # NOQA
from psycopg2.extras import ( # NOQA
DateRange,
DateTimeRange,
DateTimeTZRange,
Inet,
Json,
NumericRange,
Range,
)
RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
class IsolationLevel(IntEnum):
READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
def _quote(value, connection=None):
adapted = extensions.adapt(value)
if hasattr(adapted, "encoding"):
adapted.encoding = "utf8"
# getquoted() returns a quoted bytestring of the adapted value.
return adapted.getquoted().decode()
sql.quote = _quote
def mogrify(sql, params, connection):
with connection.cursor() as cursor:
return cursor.mogrify(sql, params).decode()
is_psycopg3 = False
class Jsonb(Json):
def getquoted(self):
quoted = super().getquoted()
return quoted + b"::jsonb"
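
Both branches above expose the same small helper surface; a hypothetical illustration:

# mogrify("SELECT %s + %s", [1, 2], connection)   -> "SELECT 1 + 2"
# IsolationLevel.SERIALIZABLE                     -> valid OPTIONS["isolation_level"] value
# RANGE_TYPES                                     -> classes recognized as range values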


@@ -0,0 +1,380 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
from django.db.backends.postgresql.psycopg_any import sql
from django.db.backends.utils import strip_quotes
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Setting all constraints to IMMEDIATE to allow changing data in the same
# transaction.
sql_update_with_default = (
"UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
"; SET CONSTRAINTS ALL IMMEDIATE"
)
sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_create_index = (
"CREATE INDEX %(name)s ON %(table)s%(using)s "
"(%(columns)s)%(include)s%(extra)s%(condition)s"
)
sql_create_index_concurrently = (
"CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
"(%(columns)s)%(include)s%(extra)s%(condition)s"
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
"CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
"; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = (
"SET CONSTRAINTS %(name)s IMMEDIATE; "
"ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
)
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
def execute(self, sql, params=()):
# Merge the query client-side, as PostgreSQL won't do it server-side.
if params is None:
return super().execute(sql, params)
sql = self.connection.ops.compose_sql(str(sql), params)
# Don't let the superclass touch anything.
return super().execute(sql, None)
sql_add_identity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
"GENERATED BY DEFAULT AS IDENTITY"
)
sql_drop_indentity = (
"ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
)
def quote_value(self, value):
return sql.quote(value, self.connection.connection)
def _field_indexes_sql(self, model, field):
output = super()._field_indexes_sql(model, field)
like_index_statement = self._create_like_index_sql(model, field)
if like_index_statement is not None:
output.append(like_index_statement)
return output
def _field_data_type(self, field):
if field.is_relation:
return field.rel_db_type(self.connection)
return self.connection.data_types.get(
field.get_internal_type(),
field.db_type(self.connection),
)
def _field_base_data_types(self, field):
# Yield base data types for array fields.
if field.base_field.get_internal_type() == "ArrayField":
yield from self._field_base_data_types(field.base_field)
else:
yield self._field_data_type(field.base_field)
def _create_like_index_sql(self, model, field):
"""
Return the statement to create an index with varchar operator pattern
when the column type is 'varchar' or 'text', otherwise return None.
"""
db_type = field.db_type(connection=self.connection)
if db_type is not None and (field.db_index or field.unique):
# Fields with database column types of `varchar` and `text` need
# a second index that specifies their operator class, which is
# needed when performing correct LIKE queries outside the
# C locale. See #12234.
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
if "[" in db_type:
return None
# Non-deterministic collations on PostgreSQL don't support indexes
# for operator classes varchar_pattern_ops/text_pattern_ops.
collation_name = getattr(field, "db_collation", None)
if not collation_name and field.is_relation:
collation_name = getattr(field.target_field, "db_collation", None)
if collation_name and not self._is_collation_deterministic(collation_name):
return None
if db_type.startswith("varchar"):
return self._create_index_sql(
model,
fields=[field],
suffix="_like",
opclasses=["varchar_pattern_ops"],
)
elif db_type.startswith("text"):
return self._create_index_sql(
model,
fields=[field],
suffix="_like",
opclasses=["text_pattern_ops"],
)
return None
def _using_sql(self, new_field, old_field):
if new_field.generated:
return ""
using_sql = " USING %(column)s::%(type)s"
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
# Compare base data types for array fields.
if list(self._field_base_data_types(old_field)) != list(
self._field_base_data_types(new_field)
):
return using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
return using_sql
return ""
def _get_sequence_name(self, table, column):
with self.connection.cursor() as cursor:
for sequence in self.connection.introspection.get_sequences(cursor, table):
if sequence["column"] == column:
return sequence["name"]
return None
def _is_changing_type_of_indexed_text_column(self, old_field, old_type, new_type):
return (old_field.db_index or old_field.unique) and (
(old_type.startswith("varchar") and not new_type.startswith("varchar"))
or (old_type.startswith("text") and not new_type.startswith("text"))
or (old_type.startswith("citext") and not new_type.startswith("citext"))
)
def _alter_column_type_sql(
self, model, old_field, new_field, new_type, old_collation, new_collation
):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
old_db_params = old_field.db_parameters(connection=self.connection)
old_type = old_db_params["type"]
if self._is_changing_type_of_indexed_text_column(old_field, old_type, new_type):
index_name = self._create_index_name(
model._meta.db_table, [old_field.column], suffix="_like"
)
self.execute(self._delete_index_sql(model, index_name))
self.sql_alter_column_type = (
"ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
)
# Cast when data type changed.
if using_sql := self._using_sql(new_field, old_field):
self.sql_alter_column_type += using_sql
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
# Make ALTER TYPE with IDENTITY make sense.
table = strip_quotes(model._meta.db_table)
auto_field_types = {
"AutoField",
"BigAutoField",
"SmallAutoField",
}
old_is_auto = old_internal_type in auto_field_types
new_is_auto = new_internal_type in auto_field_types
if new_is_auto and not old_is_auto:
column = strip_quotes(new_field.column)
return (
(
self.sql_alter_column_type
% {
"column": self.quote_name(column),
"type": new_type,
"collation": "",
},
[],
),
[
(
self.sql_add_identity
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
},
[],
),
],
)
elif old_is_auto and not new_is_auto:
# Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
# it).
self.execute(
self.sql_drop_indentity
% {
"table": self.quote_name(table),
"column": self.quote_name(strip_quotes(new_field.column)),
}
)
column = strip_quotes(new_field.column)
fragment, _ = super()._alter_column_type_sql(
model, old_field, new_field, new_type, old_collation, new_collation
)
# Drop the sequence if exists (Django 4.1+ identity columns don't
# have it).
other_actions = []
if sequence_name := self._get_sequence_name(table, column):
other_actions = [
(
self.sql_delete_sequence
% {
"sequence": self.quote_name(sequence_name),
},
[],
)
]
return fragment, other_actions
elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
fragment, _ = super()._alter_column_type_sql(
model, old_field, new_field, new_type, old_collation, new_collation
)
column = strip_quotes(new_field.column)
db_types = {
"AutoField": "integer",
"BigAutoField": "bigint",
"SmallAutoField": "smallint",
}
# Alter the sequence type if exists (Django 4.1+ identity columns
# don't have it).
other_actions = []
if sequence_name := self._get_sequence_name(table, column):
other_actions = [
(
self.sql_alter_sequence_type
% {
"sequence": self.quote_name(sequence_name),
"type": db_types[new_internal_type],
},
[],
),
]
return fragment, other_actions
else:
return super()._alter_column_type_sql(
model, old_field, new_field, new_type, old_collation, new_collation
)
def _alter_field(
self,
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict=False,
):
super()._alter_field(
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
if (
(not (old_field.db_index or old_field.unique) and new_field.db_index)
or (not old_field.unique and new_field.unique)
or (
self._is_changing_type_of_indexed_text_column(
old_field, old_type, new_type
)
)
):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
index_to_remove = self._create_index_name(
model._meta.db_table, [old_field.column], suffix="_like"
)
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
return IndexColumns(
table,
columns,
self.quote_name,
col_suffixes=col_suffixes,
opclasses=opclasses,
)
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
self.execute(
index.create_sql(model, self, concurrently=concurrently), params=None
)
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
sql = (
self.sql_delete_index_concurrently
if concurrently
else self.sql_delete_index
)
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
self,
model,
*,
fields=None,
name=None,
suffix="",
using="",
db_tablespace=None,
col_suffixes=(),
sql=None,
opclasses=(),
condition=None,
concurrently=False,
include=None,
expressions=None,
):
sql = sql or (
self.sql_create_index
if not concurrently
else self.sql_create_index_concurrently
)
return super()._create_index_sql(
model,
fields=fields,
name=name,
suffix=suffix,
using=using,
db_tablespace=db_tablespace,
col_suffixes=col_suffixes,
sql=sql,
opclasses=opclasses,
condition=condition,
include=include,
expressions=expressions,
)
def _is_collation_deterministic(self, collation_name):
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT collisdeterministic
FROM pg_collation
WHERE collname = %s
""",
[collation_name],
)
row = cursor.fetchone()
return row[0] if row else None
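
To illustrate _create_like_index_sql() above: for a hypothetical indexed field such as name = models.CharField(max_length=100, db_index=True), the editor emits a second index roughly of the form:

# CREATE INDEX "app_model_name_<hash>_like" ON "app_model" ("name" varchar_pattern_ops)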