You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
371 lines
12 KiB
371 lines
12 KiB
import copy
|
|
import inspect
|
|
import threading
|
|
import warnings
|
|
from collections import OrderedDict
|
|
|
|
from django.conf import settings
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
from django.utils.module_loading import module_has_submodule
|
|
|
|
from haystack import constants
|
|
from haystack.exceptions import NotHandled, SearchFieldError
|
|
from haystack.utils import importlib
|
|
from haystack.utils.app_loading import haystack_get_app_modules
|
|
|
|
|
|
def import_class(path):
    """
    Resolve a dotted path such as ``package.module.ClassName``.

    Everything before the final dot is imported as a module; the final
    segment is then looked up on that module and returned.

    Raises ``ImportError`` when the imported module has no attribute with
    that name. Errors raised by the import itself propagate unchanged.
    """
    # Split on the last dot: module path on the left, class name on the right.
    module_path, _, class_name = path.rpartition(".")
    module_itself = importlib.import_module(module_path)

    # A private sentinel distinguishes "attribute missing" from any real value.
    not_found = object()
    resolved = getattr(module_itself, class_name, not_found)

    if resolved is not_found:
        raise ImportError(
            "The Python module '%s' has no '%s' class." % (module_path, class_name)
        )

    return resolved
|
|
|
|
|
|
# Load the search backend.
|
|
def load_backend(full_backend_path):
    """
    Import and return the engine class named by ``full_backend_path``.

    ``full_backend_path`` must be a dotted Python import path made of at
    least a module part and a class name. One example is the path that
    ``ConnectionHandler.ensure_defaults`` falls back to::

        haystack.backends.simple_backend.SimpleEngine

    A custom engine is referenced the same way, for example::

        myapp.search_backends.CustomSolrEngine

    The target is expected to be a ``BaseEngine`` subclass, but this
    function only validates the shape of the path, not the class itself.

    Raises ``ImproperlyConfigured`` when the path contains no dot. Errors
    from ``import_class`` propagate unchanged.
    """
    # A usable path has at least one segment after the first dot.
    _first_segment, *remaining_segments = full_backend_path.split(".")

    if not remaining_segments:
        raise ImproperlyConfigured(
            "The provided backend '%s' is not a complete Python path to a BaseEngine subclass."
            % full_backend_path
        )

    return import_class(full_backend_path)
|
|
|
|
|
|
def load_router(full_router_path):
    """
    Import and return the router class named by ``full_router_path``.

    ``full_router_path`` must be a dotted Python import path made of at
    least a module part and a class name. The default used by
    ``ConnectionRouter`` is::

        haystack.routers.DefaultRouter

    A custom router is referenced the same way, for example::

        myapp.search_routers.MasterSlaveRouter

    The target is expected to be a ``BaseRouter`` subclass, but this
    function only validates the shape of the path, not the class itself.

    Raises ``ImproperlyConfigured`` when the path contains no dot. Errors
    from ``import_class`` propagate unchanged.
    """
    # A usable path has at least one segment after the first dot.
    _first_segment, *remaining_segments = full_router_path.split(".")

    if not remaining_segments:
        raise ImproperlyConfigured(
            "The provided router '%s' is not a complete Python path to a BaseRouter subclass."
            % full_router_path
        )

    return import_class(full_router_path)
|
|
|
|
|
|
class ConnectionHandler:
    """
    Lazily creates backend connections and caches them per thread.

    ``connections_info`` maps a connection alias to that connection's
    options mapping. Indexing the handler with an alias returns the cached
    connection for the current thread, building it on first use.
    """

    def __init__(self, connections_info):
        # Each thread gets its own ``connections`` dict on this object.
        self.thread_local = threading.local()
        self.connections_info = connections_info
        self._index = None

    def ensure_defaults(self, alias):
        """
        Make sure the options for ``alias`` name an ``ENGINE``.

        Raises ``ImproperlyConfigured`` if ``alias`` is not configured. A
        missing or falsy ``ENGINE`` is replaced in place with the simple
        backend's engine path.
        """
        try:
            options = self.connections_info[alias]
        except KeyError:
            raise ImproperlyConfigured(
                "The key '%s' isn't an available connection." % alias
            )

        if options.get("ENGINE"):
            return

        options["ENGINE"] = "haystack.backends.simple_backend.SimpleEngine"

    def __getitem__(self, key):
        """Return this thread's connection for ``key``, creating it if needed."""
        try:
            cached = self.thread_local.connections
        except AttributeError:
            # First access from this thread: start with an empty cache.
            self.thread_local.connections = {}
        else:
            if key in cached:
                return cached[key]

        self.ensure_defaults(key)
        engine_class = load_backend(self.connections_info[key]["ENGINE"])
        connection = engine_class(using=key)
        self.thread_local.connections[key] = connection
        return connection

    def reload(self, key):
        """Discard this thread's cached connection for ``key`` and rebuild it."""
        if not hasattr(self.thread_local, "connections"):
            self.thread_local.connections = {}

        # Dropping an alias that was never cached is not an error.
        self.thread_local.connections.pop(key, None)

        return self.__getitem__(key)

    def all(self):  # noqa A003
        """Return a list with the connection for every configured alias."""
        connections = []
        for alias in self.connections_info:
            connections.append(self[alias])
        return connections
|
|
|
|
|
|
class ConnectionRouter:
    """
    Consults the configured routers to decide which connection(s) to use.

    Router classes come from ``settings.HAYSTACK_ROUTERS`` (falling back to
    ``haystack.routers.DefaultRouter``) and are instantiated lazily on first
    access to ``routers``.
    """

    def __init__(self):
        # ``None`` means "not loaded yet"; see the ``routers`` property.
        self._routers = None

    @property
    def routers(self):
        """Return the list of router instances, loading them on first use."""
        if self._routers is None:
            default_routers = ["haystack.routers.DefaultRouter"]
            router_list = getattr(settings, "HAYSTACK_ROUTERS", default_routers)
            # In case HAYSTACK_ROUTERS is empty (or None), fall back to the
            # default routers.
            if not router_list:
                router_list = default_routers

            # Build the full list before publishing it: if any router fails
            # to load, the cache stays unset and the next access retries
            # instead of silently reusing a partially populated list.
            self._routers = [
                load_router(router_path)() for router_path in router_list
            ]
        return self._routers

    def _for_action(self, action, many, **hints):
        """
        Collect the connection aliases the routers suggest for ``action``.

        Each router that has an ``action`` attribute is called with
        ``**hints``. A ``None`` result is skipped, a string is appended, and
        any other result is treated as an iterable of aliases and extended
        into the result. When ``many`` is false, the search stops after the
        first router that returns a non-``None`` result.
        """
        conns = []

        for router in self.routers:
            if not hasattr(router, action):
                continue

            connection_to_use = getattr(router, action)(**hints)
            if connection_to_use is None:
                continue

            if isinstance(connection_to_use, str):
                conns.append(connection_to_use)
            else:
                conns.extend(connection_to_use)

            if not many:
                break

        return conns

    def for_write(self, **hints):
        """Return every connection alias the routers suggest for writing."""
        return self._for_action("for_write", True, **hints)

    def for_read(self, **hints):
        """
        Return the first connection alias suggested for reading.

        Raises ``IndexError`` if no router suggests a connection.
        """
        return self._for_action("for_read", False, **hints)[0]
|
|
|
|
|
|
class UnifiedIndex:
    """
    Collects all the search indexes into a cohesive whole.

    Holds a mapping of model -> index instance together with a single
    merged set of fields (keyed by ``index_fieldname``) across all indexes.
    The data is built lazily: the ``get_*`` accessors and
    ``all_searchfields`` call ``build()`` on first use.
    """

    def __init__(self, excluded_indexes=None):
        self._indexes = {}
        self.fields = OrderedDict()
        self._built = False
        # Dotted class paths of index classes that must be skipped.
        self.excluded_indexes = excluded_indexes or []
        # Class name -> ``id()`` of index classes that have been skipped.
        self.excluded_indexes_ids = {}
        self.document_field = constants.DOCUMENT_FIELD
        # Field name -> ``index_fieldname`` recorded by ``collect_fields``.
        self._fieldnames = {}
        # Faceted-for name (or instance name) -> facet field name.
        self._facet_fieldnames = {}

    @property
    def indexes(self):
        """Deprecated accessor for the raw model -> index mapping."""
        warnings.warn(
            "'UnifiedIndex.indexes' was deprecated in Haystack v2.3.0. Please use UnifiedIndex.get_indexes().",
            # Attribute the warning to the code reading the property.
            stacklevel=2,
        )
        return self._indexes

    def collect_indexes(self):
        """
        Instantiate every index class found in the apps' ``search_indexes``.

        For each module returned by ``haystack_get_app_modules()`` the
        ``<app>.search_indexes`` submodule is imported. Classes in it that
        have a truthy ``haystack_use_for_indexing`` and a ``get_model``
        attribute are instantiated, unless they are excluded.
        """
        indexes = []

        for app_mod in haystack_get_app_modules():
            try:
                search_index_module = importlib.import_module(
                    "%s.search_indexes" % app_mod.__name__
                )
            except ImportError:
                # Re-raise only when the submodule exists, i.e. the failure
                # came from inside it. An app without one is skipped.
                if module_has_submodule(app_mod, "search_indexes"):
                    raise

                continue

            for item_name, item in inspect.getmembers(
                search_index_module, inspect.isclass
            ):
                if not (
                    getattr(item, "haystack_use_for_indexing", False)
                    and getattr(item, "get_model", None)
                ):
                    continue

                # We've got an index. Check if we should be ignoring it.
                class_path = "%s.search_indexes.%s" % (app_mod.__name__, item_name)

                if (
                    class_path in self.excluded_indexes
                    or self.excluded_indexes_ids.get(item_name) == id(item)
                ):
                    # Remember the class object so the same class is also
                    # skipped when it appears again under the same name
                    # (presumably when imported into another app's module).
                    self.excluded_indexes_ids[str(item_name)] = id(item)
                    continue

                indexes.append(item())

        return indexes

    def reset(self):
        """Drop everything collected by ``build()`` and mark as not built."""
        self._indexes = {}
        self.fields = OrderedDict()
        self._built = False
        self._fieldnames = {}
        self._facet_fieldnames = {}

    def build(self, indexes=None):
        """
        Rebuild the model -> index mapping and the unified field set.

        ``indexes`` is an iterable of index instances; when ``None`` they
        are discovered with ``collect_indexes()``.

        Raises ``ImproperlyConfigured`` if two indexes handle the same
        model.
        """
        self.reset()

        if indexes is None:
            indexes = self.collect_indexes()

        for index in indexes:
            model = index.get_model()

            if model in self._indexes:
                raise ImproperlyConfigured(
                    "Model '%s' has more than one 'SearchIndex`` handling it. "
                    "Please exclude either '%s' or '%s' using the 'EXCLUDED_INDEXES' "
                    "setting defined in 'settings.HAYSTACK_CONNECTIONS'."
                    % (model, self._indexes[model], index)
                )

            self._indexes[model] = index
            self.collect_fields(index)

        self._built = True

    def collect_fields(self, index):
        """
        Merge the fields of ``index`` into the unified schema.

        Raises ``SearchFieldError`` if a ``document=True`` field does not
        use ``self.document_field`` as its ``index_fieldname``, or if a
        field name was already seen with a different ``index_fieldname``.
        """
        for fieldname, field_object in index.fields.items():
            if field_object.document is True:
                if field_object.index_fieldname != self.document_field:
                    raise SearchFieldError(
                        "All 'SearchIndex' classes must use the same '%s' fieldname for the 'document=True' field. Offending index is '%s'."
                        % (self.document_field, index)
                    )

            # Stow the index_fieldname so we don't have to get it the hard way again.
            if (
                fieldname in self._fieldnames
                and field_object.index_fieldname != self._fieldnames[fieldname]
            ):
                # We've already seen this field in the list. Raise an exception if index_fieldname differs.
                raise SearchFieldError(
                    "All uses of the '%s' field need to use the same 'index_fieldname' attribute."
                    % fieldname
                )

            self._fieldnames[fieldname] = field_object.index_fieldname

            # Stow the facet_fieldname so we don't have to look that up either.
            if hasattr(field_object, "facet_for"):
                if field_object.facet_for:
                    self._facet_fieldnames[field_object.facet_for] = fieldname
                else:
                    self._facet_fieldnames[field_object.instance_name] = fieldname

            # Copy the field in so we've got a unified schema. A copy is
            # stored so that merging flags below never mutates the field
            # object owned by an index.
            if field_object.index_fieldname not in self.fields:
                self.fields[field_object.index_fieldname] = copy.copy(field_object)
            else:
                # If the field types are different, we can mostly
                # safely ignore this. The exception is ``MultiValueField``,
                # in which case we'll use it instead, copying over the
                # values.
                if field_object.is_multivalued:
                    old_field = self.fields[field_object.index_fieldname]
                    self.fields[field_object.index_fieldname] = copy.copy(field_object)

                    # Switch it so we don't have to dupe the remaining
                    # checks.
                    field_object = old_field

                # We've already got this field in the list. Ensure that
                # what we hand back is a superset of all options that
                # affect the schema.
                unified_field = self.fields[field_object.index_fieldname]
                for flag in ("indexed", "stored", "faceted", "use_template", "null"):
                    if getattr(field_object, flag) is True:
                        setattr(unified_field, flag, True)

    def get_indexes(self):
        """Return the model -> index mapping, building it if necessary."""
        if not self._built:
            self.build()

        return self._indexes

    def get_indexed_models(self):
        """Return a list of all models that have an index."""
        # Ensuring a list here since Python3 will give us an iterator
        return list(self.get_indexes().keys())

    def get_index_fieldname(self, field):
        """Return the recorded ``index_fieldname`` for ``field``, or ``field``."""
        if not self._built:
            self.build()

        return self._fieldnames.get(field) or field

    def get_index(self, model_klass):
        """
        Return the index handling ``model_klass``.

        Raises ``NotHandled`` if no index is registered for it.
        """
        indexes = self.get_indexes()

        if model_klass not in indexes:
            raise NotHandled("The model %s is not registered" % model_klass)

        return indexes[model_klass]

    def get_facet_fieldname(self, field):
        """
        Map ``field`` to the name used for it when faceting.

        If ``field`` is a key of the unified schema and its field object has
        a ``facet_for`` attribute, that value is returned, or the object's
        ``instance_name`` when ``facet_for`` is falsy. A schema field with
        no ``facet_for`` attribute is looked up in the recorded facet
        names. Anything else is returned unchanged.
        """
        if not self._built:
            self.build()

        # Direct lookup; ``self.fields`` is keyed by the name being searched.
        if field not in self.fields:
            return field

        field_object = self.fields[field]

        if not hasattr(field_object, "facet_for"):
            return self._facet_fieldnames.get(field) or field

        return field_object.facet_for or field_object.instance_name

    def all_searchfields(self):
        """Return the unified field mapping, building it if necessary."""
        if not self._built:
            self.build()

        return self.fields
|