You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

164 lines
5.4 KiB

"""Base class for notebook tests."""
from __future__ import print_function
import os
import sys
import time
import requests
from contextlib import contextmanager
from threading import Thread, Event
from unittest import TestCase
pjoin = os.path.join
try:
from unittest.mock import patch
except ImportError:
from mock import patch #py2
from tornado.ioloop import IOLoop
import zmq
import jupyter_core.paths
from ..notebookapp import NotebookApp
from ipython_genutils.tempdir import TemporaryDirectory
# Upper bound, in seconds, used both when polling for the notebook server to
# come up and when joining its thread after shutdown.
MAX_WAITTIME = 30
# Pause, in seconds, between two consecutive liveness probes.
POLL_INTERVAL = 0.1
# TimeoutError is a builtin on Python 3. This can be removed when we stop
# supporting Python 2.
class TimeoutError(Exception):
pass
class NotebookTestBase(TestCase):
    """A base class for tests that need a running notebook.

    This creates some empty config and runtime directories
    and then starts the notebook server with them.
    """

    port = 12341
    config = None
    # run with a base URL that would be escaped,
    # to test that we don't double-escape URLs
    url_prefix = '/a%40b/'

    @classmethod
    def wait_until_alive(cls):
        """Wait for the server to be alive.

        Issues GET requests to ``<base_url>api/contents`` up to
        ``MAX_WAITTIME / POLL_INTERVAL`` times, sleeping POLL_INTERVAL
        seconds after each failed attempt.

        :raises RuntimeError: if the server thread exits while waiting.
        :raises TimeoutError: if no request succeeds within the attempt budget.
        """
        url = cls.base_url() + 'api/contents'
        for _ in range(int(MAX_WAITTIME / POLL_INTERVAL)):
            try:
                requests.get(url)
            except Exception:
                # Any exception from the request is treated as "not up yet",
                # unless the thread running the server has already exited.
                if not cls.notebook_thread.is_alive():
                    raise RuntimeError("The notebook server failed to start")
                time.sleep(POLL_INTERVAL)
            else:
                return
        raise TimeoutError("The notebook server didn't start up correctly.")

    @classmethod
    def wait_until_dead(cls):
        """Wait for the server thread to terminate after shutdown.

        :raises TimeoutError: if the thread is still alive after
            MAX_WAITTIME seconds.
        """
        cls.notebook_thread.join(timeout=MAX_WAITTIME)
        if cls.notebook_thread.is_alive():
            raise TimeoutError("Undead notebook server")

    @classmethod
    def setup_class(cls):
        """Create an isolated environment and start the notebook server.

        HOME, IPYTHONDIR and JUPYTER_DATA_DIR are pointed at fresh temporary
        directories, PYTHONPATH mirrors the current sys.path, and
        jupyter_core's SYSTEM_JUPYTER_PATH is patched to an empty list. The
        server runs in a daemon thread; this method returns once the server
        answers HTTP requests (see wait_until_alive).
        """
        cls.home_dir = TemporaryDirectory()
        data_dir = TemporaryDirectory()
        cls.env_patch = patch.dict('os.environ', {
            'HOME': cls.home_dir.name,
            'PYTHONPATH': os.pathsep.join(sys.path),
            'IPYTHONDIR': pjoin(cls.home_dir.name, '.ipython'),
            'JUPYTER_DATA_DIR': data_dir.name,
        })
        cls.env_patch.start()
        cls.path_patch = patch.object(jupyter_core.paths, 'SYSTEM_JUPYTER_PATH', [])
        cls.path_patch.start()

        cls.config_dir = TemporaryDirectory()
        cls.data_dir = data_dir
        cls.runtime_dir = TemporaryDirectory()
        cls.notebook_dir = TemporaryDirectory()

        started = Event()

        def start_thread():
            try:
                app = cls.notebook = NotebookApp(
                    port=cls.port,
                    port_retries=0,
                    open_browser=False,
                    config_dir=cls.config_dir.name,
                    data_dir=cls.data_dir.name,
                    runtime_dir=cls.runtime_dir.name,
                    notebook_dir=cls.notebook_dir.name,
                    base_url=cls.url_prefix,
                    config=cls.config,
                )
                # don't register signal handler during tests
                app.init_signal = lambda: None
                # clear log handlers and propagate to root for nose to capture it
                # needs to be redone after initialize, which reconfigures logging
                app.log.propagate = True
                app.log.handlers = []
                app.initialize(argv=[])
                app.log.propagate = True
                app.log.handlers = []
                loop = IOLoop.current()
                # Scheduled on the IOLoop, so it fires once the loop is running.
                loop.add_callback(started.set)
                try:
                    app.start()
                finally:
                    app.session_manager.close()
            finally:
                # Set the event on every exit path -- including an exception
                # from NotebookApp() or initialize() -- so that started.wait()
                # below can never hang when the server fails to start.
                started.set()

        cls.notebook_thread = Thread(target=start_thread)
        cls.notebook_thread.daemon = True
        cls.notebook_thread.start()
        started.wait()
        cls.wait_until_alive()

    @classmethod
    def teardown_class(cls):
        """Stop the server and undo everything setup_class did.

        The temporary directories are removed and both patches are stopped
        even when stopping the server fails, so a broken shutdown does not
        leave os.environ / SYSTEM_JUPYTER_PATH patched for later tests.

        :raises RuntimeError: if the global zmq Context cannot be terminated
            within a few seconds.
        """
        try:
            cls.notebook.stop()
            cls.wait_until_dead()
        finally:
            cls.home_dir.cleanup()
            cls.config_dir.cleanup()
            cls.data_dir.cleanup()
            cls.runtime_dir.cleanup()
            cls.notebook_dir.cleanup()
            cls.env_patch.stop()
            cls.path_patch.stop()

        # cleanup global zmq Context, to ensure we aren't leaving dangling sockets
        def cleanup_zmq():
            zmq.Context.instance().term()

        t = Thread(target=cleanup_zmq)
        t.daemon = True
        t.start()
        t.join(5)  # give it a few seconds to clean up (this should be immediate)
        # if term never returned, there's zmq stuff still open somewhere, so shout about it.
        if t.is_alive():
            raise RuntimeError("Failed to teardown zmq Context, open sockets likely left lying around.")

    @classmethod
    def base_url(cls):
        """Return the root URL of the test server, including url_prefix."""
        return 'http://localhost:%i%s' % (cls.port, cls.url_prefix)
@contextmanager
def assert_http_error(status, msg=None):
    """Check that the ``with`` body raises ``requests.HTTPError`` with *status*.

    :param status: expected ``status_code`` of the error's response.
    :param msg: optional substring that must appear in ``str(error)``.
    :raises AssertionError: if no HTTPError is raised, if its status code
        differs from *status*, or if *msg* is given and not found in it.

    Failures are raised explicitly instead of via ``assert`` statements so
    the check still works under ``python -O``, which strips asserts.
    """
    try:
        yield
    except requests.HTTPError as e:
        real_status = e.response.status_code
        if real_status != status:
            raise AssertionError(
                "Expected status %d, got %d" % (status, real_status))
        if msg and msg not in str(e):
            raise AssertionError(e)
    else:
        raise AssertionError("Expected HTTP error status")