You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
396 lines
14 KiB
396 lines
14 KiB
import itertools
|
|
|
|
from django.core.exceptions import EmptyResultSet
|
|
from django.db import models
|
|
from django.db.models.expressions import (
|
|
ColPairs,
|
|
Exists,
|
|
Func,
|
|
ResolvedOuterRef,
|
|
Subquery,
|
|
Value,
|
|
)
|
|
from django.db.models.lookups import (
|
|
Exact,
|
|
GreaterThan,
|
|
GreaterThanOrEqual,
|
|
In,
|
|
IsNull,
|
|
LessThan,
|
|
LessThanOrEqual,
|
|
)
|
|
from django.db.models.sql import Query
|
|
from django.db.models.sql.where import AND, OR, WhereNode
|
|
|
|
|
|
class Tuple(Func):
    """
    A group of expressions rendered as a row value, e.g. ``(a, b, c)``.

    ``function`` is left empty; presumably this makes the default ``Func``
    template render only the parenthesized argument list (confirm against
    ``Func.template``). The group supports ``len()`` and iteration over its
    source expressions, so it can be used wherever a sequence of columns is
    expected.
    """

    allows_composite_expressions = True
    function = ""
    output_field = models.Field()

    def __len__(self):
        """Return how many expressions are grouped."""
        expressions = self.source_expressions
        return len(expressions)

    def __iter__(self):
        """Iterate over the grouped expressions in their original order."""
        expressions = self.source_expressions
        return iter(expressions)
|
|
|
|
|
|
class TupleLookupMixin:
    """
    Shared behavior for lookups whose left-hand side is composite: either a
    ``ColPairs`` or a ``Tuple`` built from a tuple/list of expressions.

    The mixin is combined with a builtin lookup class, for example
    ``TupleExact(TupleLookupMixin, Exact)``. On backends whose
    ``supports_tuple_lookups`` feature is disabled, ``as_sql()`` delegates to
    ``get_fallback_sql()``, which concrete lookups must implement.
    """

    allows_composite_expressions = True

    def get_prep_lookup(self):
        """
        Validate the right-hand side and return it unchanged.

        Direct values must be a tuple or list with as many elements as the
        left-hand side. Otherwise only ``OuterRef`` (resolved) and ``Query``
        objects are accepted, and the parent lookup's preparation is run on
        them; its return value is intentionally not used.

        Raises ValueError when the right-hand side is not acceptable.
        """
        if self.rhs_is_direct_value():
            self.check_rhs_is_tuple_or_list()
            self.check_rhs_length_equals_lhs_length()
        else:
            self.check_rhs_is_supported_expression()
            super().get_prep_lookup()
        return self.rhs

    def check_rhs_is_tuple_or_list(self):
        """Raise ValueError unless the right-hand side is a tuple or a list."""
        if not isinstance(self.rhs, (tuple, list)):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
            )

    def check_rhs_length_equals_lhs_length(self):
        """Raise ValueError unless both sides have the same number of elements."""
        len_lhs = len(self.lhs)
        if len_lhs != len(self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
            )

    def check_rhs_is_supported_expression(self):
        """Raise ValueError unless the right-hand side is an OuterRef or a Query."""
        if not isinstance(self.rhs, (ResolvedOuterRef, Query)):
            lhs_str = self.get_lhs_str()
            rhs_cls = self.rhs.__class__.__name__
            raise ValueError(
                f"{self.lookup_name!r} subquery lookup of {lhs_str} "
                f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
            )

    def get_lhs_str(self):
        """
        Describe the left-hand side for error messages.

        Returns the repr of the composite field's name for a ``ColPairs``,
        otherwise a parenthesized, comma-separated list of element names.
        """
        if isinstance(self.lhs, ColPairs):
            return repr(self.lhs.field.name)
        # Elements lacking a ``name`` attribute are shown via their repr()
        # instead of raising AttributeError, which would otherwise mask the
        # ValueError the caller is trying to build.
        names = ", ".join(repr(getattr(expr, "name", expr)) for expr in self.lhs)
        return f"({names})"

    def get_prep_lhs(self):
        """Wrap a tuple/list left-hand side in a ``Tuple`` expression."""
        if isinstance(self.lhs, (tuple, list)):
            return Tuple(*self.lhs)
        return super().get_prep_lhs()

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile the left-hand side, parenthesizing it unless it is a Tuple."""
        sql, params = super().process_lhs(compiler, connection, lhs)
        # A Tuple is assumed to render its own parentheses.
        if not isinstance(self.lhs, Tuple):
            sql = f"({sql})"
        return sql, params

    def process_rhs(self, compiler, connection):
        """
        Compile the right-hand side.

        * Direct values become a ``Tuple`` of ``Value`` objects typed after the
          matching left-hand column; elements that are already expressions
          (have ``as_sql``) are used as-is.
        * A ``ColPairs`` is compiled and parenthesized.
        * A ``Query`` is handed to the parent lookup's ``process_rhs()``.

        Raises ValueError for any other expression.
        """
        if self.rhs_is_direct_value():
            args = [
                (
                    val
                    if hasattr(val, "as_sql")
                    else Value(val, output_field=col.output_field)
                )
                for col, val in zip(self.lhs, self.rhs)
            ]
            return compiler.compile(Tuple(*args))
        if isinstance(self.rhs, ColPairs):
            sql, params = compiler.compile(self.rhs)
            return "(%s)" % sql, params
        if isinstance(self.rhs, Query):
            # Any SQL compiled here would be discarded, so go straight to the
            # parent implementation instead of compiling the subquery twice.
            return super().process_rhs(compiler, connection)
        # Reject unsupported expressions before attempting to compile them.
        raise ValueError(
            "Composite field lookups only work with composite expressions."
        )

    def get_fallback_sql(self, compiler, connection):
        """
        Return SQL for backends without native tuple lookups.

        Must be overridden by concrete lookups; raises NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__}.get_fallback_sql() must be implemented "
            f"for backends that don't have the supports_tuple_lookups feature enabled."
        )

    def as_sql(self, compiler, connection):
        """Use the fallback SQL on backends without native tuple lookups."""
        if not connection.features.supports_tuple_lookups:
            return self.get_fallback_sql(compiler, connection)
        return super().as_sql(compiler, connection)
|
|
|
|
|
|
class TupleExact(TupleLookupMixin, Exact):
    """Composite ``exact`` lookup, e.g. ``(a, b, c) == (x, y, z)``."""

    def get_fallback_sql(self, compiler, connection):
        """
        Emulate a row-value equality on backends without tuple lookups.

        A ``Query`` right-hand side is delegated to the plain ``Exact`` SQL;
        anything else is expanded into per-column equalities joined by AND:

            a = x AND b = y AND c = z
        """
        if not isinstance(self.rhs, Query):
            # Process right-hand-side to trigger sanitization; the compiled
            # SQL itself is not used here.
            self.process_rhs(compiler, connection)
            equalities = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)]
            conjunction = WhereNode(equalities, connector=AND)
            return conjunction.as_sql(compiler, connection)
        return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
|
|
|
|
|
class TupleIsNull(TupleLookupMixin, IsNull):
    """Composite ``isnull`` lookup over every column of the left-hand side."""

    def get_prep_lookup(self):
        """
        Validate the right-hand side and return it as a bool.

        A one-element tuple/list such as ``(True,)`` is unwrapped first.
        Any value that is not a bool raises ValueError.
        """
        value = self.rhs
        if isinstance(value, (tuple, list)) and len(value) == 1:
            (value,) = value
        if not isinstance(value, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        return value

    def as_sql(self, compiler, connection):
        """
        Expand into per-column NULL checks.

        isnull=True matches when ANY column is NULL:
            a IS NULL OR b IS NULL OR c IS NULL
        isnull=False requires EVERY column to be non-NULL:
            a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
        """
        is_null = self.rhs
        checks = [IsNull(col, is_null) for col in self.lhs]
        connector = OR if is_null else AND
        return WhereNode(checks, connector=connector).as_sql(compiler, connection)
|
|
|
|
|
|
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
    """Composite ``gt`` lookup, e.g. ``(a, b, c) > (x, y, z)``."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand the row comparison into nested per-column conditions:

            a > x OR (a = x AND (b > y OR (b = y AND c > z)))
        """
        # Process right-hand-side to trigger sanitization; the compiled SQL
        # itself is not used here.
        self.process_rhs(compiler, connection)
        pairs = list(zip(self.lhs, self.rhs))
        last_position = len(pairs) - 1
        # Flat sequence of single-condition nodes alternating "col > val"
        # (OR-connected) and "col = val" (AND-connected). The final column
        # needs no equality tie-breaker.
        chain = []
        for position, (col, val) in enumerate(pairs):
            chain.append(WhereNode([GreaterThan(col, val)], connector=OR))
            if position != last_position:
                chain.append(WhereNode([Exact(col, val)], connector=AND))
        # Nest each node inside its predecessor to form the expression tree.
        for outer, inner in zip(chain, chain[1:]):
            outer.children.append(inner)
        return chain[0].as_sql(compiler, connection)
|
|
|
|
|
|
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
    """Composite ``gte`` lookup, e.g. ``(a, b, c) >= (x, y, z)``."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand the row comparison into nested per-column conditions:

            a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
        """
        # Process right-hand-side to trigger sanitization; the compiled SQL
        # itself is not used here.
        self.process_rhs(compiler, connection)
        pairs = list(zip(self.lhs, self.rhs))
        # Flat sequence of single-condition nodes alternating "col > val"
        # (OR-connected) and "col = val" (AND-connected). Unlike ``gt``, the
        # final column keeps its equality so that equal rows also match.
        chain = []
        for col, val in pairs:
            chain.append(WhereNode([GreaterThan(col, val)], connector=OR))
            chain.append(WhereNode([Exact(col, val)], connector=AND))
        # Nest each node inside its predecessor to form the expression tree.
        for outer, inner in zip(chain, chain[1:]):
            outer.children.append(inner)
        return chain[0].as_sql(compiler, connection)
|
|
|
|
|
|
class TupleLessThan(TupleLookupMixin, LessThan):
    """Composite ``lt`` lookup, e.g. ``(a, b, c) < (x, y, z)``."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand the row comparison into nested per-column conditions:

            a < x OR (a = x AND (b < y OR (b = y AND c < z)))
        """
        # Process right-hand-side to trigger sanitization; the compiled SQL
        # itself is not used here.
        self.process_rhs(compiler, connection)
        pairs = list(zip(self.lhs, self.rhs))
        last_position = len(pairs) - 1
        # Flat sequence of single-condition nodes alternating "col < val"
        # (OR-connected) and "col = val" (AND-connected). The final column
        # needs no equality tie-breaker.
        chain = []
        for position, (col, val) in enumerate(pairs):
            chain.append(WhereNode([LessThan(col, val)], connector=OR))
            if position != last_position:
                chain.append(WhereNode([Exact(col, val)], connector=AND))
        # Nest each node inside its predecessor to form the expression tree.
        for outer, inner in zip(chain, chain[1:]):
            outer.children.append(inner)
        return chain[0].as_sql(compiler, connection)
|
|
|
|
|
|
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
    """Composite ``lte`` lookup, e.g. ``(a, b, c) <= (x, y, z)``."""

    def get_fallback_sql(self, compiler, connection):
        """
        Expand the row comparison into nested per-column conditions:

            a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
        """
        # Process right-hand-side to trigger sanitization; the compiled SQL
        # itself is not used here.
        self.process_rhs(compiler, connection)
        pairs = list(zip(self.lhs, self.rhs))
        # Flat sequence of single-condition nodes alternating "col < val"
        # (OR-connected) and "col = val" (AND-connected). Unlike ``lt``, the
        # final column keeps its equality so that equal rows also match.
        chain = []
        for col, val in pairs:
            chain.append(WhereNode([LessThan(col, val)], connector=OR))
            chain.append(WhereNode([Exact(col, val)], connector=AND))
        # Nest each node inside its predecessor to form the expression tree.
        for outer, inner in zip(chain, chain[1:]):
            outer.children.append(inner)
        return chain[0].as_sql(compiler, connection)
|
|
|
|
|
|
class TupleIn(TupleLookupMixin, In):
    """
    Composite ``in`` lookup, e.g. ``(a, b) IN ((x1, y1), (x2, y2))``.

    The right-hand side is either a tuple/list of tuples/lists (one per row)
    or a subquery (``Query`` or ``Subquery``).
    """

    def get_prep_lookup(self):
        """
        Validate the right-hand side and return it unchanged.

        This replaces the mixin's checks: direct values must be a tuple/list
        of tuples/lists whose lengths match the left-hand side; anything else
        must be a Query or Subquery, which then goes through ``In``'s own
        preparation (the class after the mixin in the MRO).
        """
        if not self.rhs_is_direct_value():
            self.check_rhs_is_query()
            super(TupleLookupMixin, self).get_prep_lookup()
        else:
            self.check_rhs_is_tuple_or_list()
            self.check_rhs_is_collection_of_tuples_or_lists()
            self.check_rhs_elements_length_equals_lhs_length()
        return self.rhs  # skip checks from mixin

    def check_rhs_is_collection_of_tuples_or_lists(self):
        """Raise ValueError if any right-hand side row is not a tuple or list."""
        if any(not isinstance(row, (tuple, list)) for row in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                "must be a collection of tuples or lists"
            )

    def check_rhs_elements_length_equals_lhs_length(self):
        """Raise ValueError if any row's length differs from the left-hand side's."""
        expected = len(self.lhs)
        if any(len(row) != expected for row in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                f"must have {expected} elements each"
            )

    def check_rhs_is_query(self):
        """Raise ValueError unless the right-hand side is a Query or Subquery."""
        if isinstance(self.rhs, (Query, Subquery)):
            return
        lhs_str = self.get_lhs_str()
        rhs_cls = self.rhs.__class__.__name__
        raise ValueError(
            f"{self.lookup_name!r} subquery lookup of {lhs_str} "
            f"must be a Query object (received {rhs_cls!r})"
        )

    def process_rhs(self, compiler, connection):
        """
        Compile the right-hand side as a tuple of row values.

        (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] renders as
            (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))

        Rows containing None are dropped because NULL never compares equal.
        Raises EmptyResultSet when no rows remain. Subqueries are delegated
        to ``In.process_rhs()``.
        """
        if not self.rhs_is_direct_value():
            return super(TupleLookupMixin, self).process_rhs(compiler, connection)
        if not self.rhs:
            raise EmptyResultSet
        columns = self.lhs
        rows = []
        for values in self.rhs:
            if any(value is None for value in values):
                continue
            rows.append(
                Tuple(
                    *(
                        value
                        if hasattr(value, "as_sql")
                        else Value(value, output_field=column.output_field)
                        for column, value in zip(columns, values)
                    )
                )
            )
        if not rows:
            raise EmptyResultSet
        return compiler.compile(Tuple(*rows))

    def get_fallback_sql(self, compiler, connection):
        """
        Emulate the lookup on backends without tuple lookups.

        * A ``Query`` becomes ``EXISTS(<subquery filtered on col = selected>)``;
          selected ``ColPairs`` are flattened into their columns.
        * Other non-direct values are delegated to ``In.as_sql()``.
        * Direct values become an OR of per-row AND-ed equalities:
              (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
          Rows containing None are dropped because NULL never compares equal.

        Raises EmptyResultSet for an empty right-hand side or when every row
        was dropped.
        """
        rhs = self.rhs
        if not rhs:
            raise EmptyResultSet
        if isinstance(rhs, Query):
            # The select list is read from the original query before cloning.
            selected = rhs.get_compiler(connection=connection).get_select()[0]
            inner_exprs = []
            for entry in selected:
                expr = entry[0]
                if isinstance(expr, ColPairs):
                    inner_exprs.extend(expr)
                else:
                    inner_exprs.append(expr)
            subquery = rhs.clone()
            subquery.add_q(
                models.Q(
                    *[Exact(col, expr) for col, expr in zip(self.lhs, inner_exprs)]
                )
            )
            return compiler.compile(Exists(subquery))
        if not self.rhs_is_direct_value():
            return super(TupleLookupMixin, self).as_sql(compiler, connection)
        columns = self.lhs
        alternatives = [
            WhereNode(
                [Exact(col, val) for col, val in zip(columns, values)],
                connector=AND,
            )
            for values in rhs
            if not any(val is None for val in values)
        ]
        if not alternatives:
            raise EmptyResultSet
        return WhereNode(alternatives, connector=OR).as_sql(compiler, connection)
|
|
|
|
|
|
# Registry of composite lookup implementations, keyed by lookup name.
tuple_lookups = dict(
    [
        ("exact", TupleExact),
        ("gt", TupleGreaterThan),
        ("gte", TupleGreaterThanOrEqual),
        ("lt", TupleLessThan),
        ("lte", TupleLessThanOrEqual),
        ("in", TupleIn),
        ("isnull", TupleIsNull),
    ]
)
|