You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

208 lines
7.9 KiB

# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import logging
from ..base.handlers import IPythonHandler
from ..utils import url_path_join
from tornado import gen, web
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler, websocket_connect
from tornado.httpclient import HTTPRequest
from tornado.escape import url_escape, json_decode, utf8
from ipython_genutils.py3compat import cast_unicode
from jupyter_client.session import Session
from traitlets.config.configurable import LoggingConfigurable
from .managers import GatewayClient
class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler):
    """Notebook-side websocket that proxies a kernel's channels to a gateway.

    Frames received from the notebook client are forwarded to a
    ``GatewayWebSocketClient``; frames the gateway sends back arrive through
    :meth:`write_message` (used as the gateway client's message callback).
    """

    session = None  # jupyter_client Session, created in initialize()
    gateway = None  # GatewayWebSocketClient, created in initialize()
    kernel_id = None  # kernel id from the request URL, set in get()

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets"""
        pass

    def get_compression_options(self):
        """Return websocket compression options.

        An empty dict requests deflate compression with default settings.
        """
        # use deflate compress websocket
        return {}

    def authenticate(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises:
            web.HTTPError: 403 when the request has no authenticated user.
        """
        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # Single lookup; an absent or empty argument is treated as "not specified".
        session_id = self.get_argument('session_id', None)
        if session_id:
            self.session.session = cast_unicode(session_id)
        else:
            self.log.warning("No session ID specified")

    def initialize(self):
        """Create the per-connection Session and the gateway proxy client."""
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)
        self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)

    @gen.coroutine
    def get(self, kernel_id, *args, **kwargs):
        """Authenticate, remember the kernel id, then run the websocket handshake.

        In Tornado >= 5, ``WebSocketHandler.get`` returns an awaitable (a native
        coroutine in Tornado 6). It must be yielded: otherwise the handshake
        never runs (Tornado 6), or its errors are silently dropped.
        """
        self.authenticate()
        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        yield super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs)

    def open(self, kernel_id, *args, **kwargs):
        """Handle web socket connection open to notebook server and delegate to gateway web socket handler """
        self.gateway.on_open(
            kernel_id=kernel_id,
            message_callback=self.write_message,
            compression_options=self.get_compression_options()
        )

    def on_message(self, message):
        """Forward message to gateway web socket handler."""
        # Lazy %-args: the (possibly large) message is only formatted if debug is enabled.
        self.log.debug("Sending message to gateway: %s", message)
        self.gateway.on_message(message)

    def write_message(self, message, binary=False):
        """Send message back to notebook client. This is called via callback from self.gateway._read_messages.

        Returns:
            Whatever the parent ``write_message`` returns when the client is
            still connected; None when the message is dropped.
        """
        self.log.debug("Receiving message from gateway: %s", message)
        if self.ws_connection:  # prevent WebSocketClosedError
            return super(WebSocketChannelsHandler, self).write_message(message, binary=binary)

        if self.log.isEnabledFor(logging.DEBUG):
            # This runs inside the gateway read loop; a payload that is not a JSON
            # kernel message (e.g. binary) must not raise and abort that loop.
            try:
                msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
            except (ValueError, KeyError, TypeError):
                msg_summary = '<unparseable message>'
            self.log.debug("Notebook client closed websocket connection - message dropped: %s", msg_summary)
        return None

    def on_close(self):
        """Tear down the gateway connection when the notebook client disconnects."""
        self.log.debug("Closing websocket connection %s", self.request.path)
        self.gateway.on_close()
        super(WebSocketChannelsHandler, self).on_close()

    @staticmethod
    def _get_message_summary(message):
        """Return a short, log-safe summary of a decoded kernel message.

        Only the message type is always shown. Content is included only for
        ``status`` (execution state) and ``error`` (ename/evalue/traceback);
        other types are elided because they may carry sensitive data.
        """
        summary = []
        message_type = message['msg_type']
        summary.append('type: {}'.format(message_type))

        if message_type == 'status':
            summary.append(', state: {}'.format(message['content']['execution_state']))
        elif message_type == 'error':
            summary.append(', {}:{}:{}'.format(message['content']['ename'],
                                              message['content']['evalue'],
                                              message['content']['traceback']))
        else:
            summary.append(', ...')  # don't display potentially sensitive data

        return ''.join(summary)
class GatewayWebSocketClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway."""

    def __init__(self, **kwargs):
        super(GatewayWebSocketClient, self).__init__(**kwargs)
        self.kernel_id = None
        self.ws = None  # gateway websocket connection, set by _connection_done on success
        self.ws_future = Future()  # placeholder; _connect replaces it with the real connect future
        self.ws_future_cancelled = False  # set by _disconnect when the pending connect is abandoned

    def _connect(self, kernel_id):
        """Start connecting to the gateway's channels websocket for ``kernel_id``.

        Replaces ``self.ws_future`` with the pending connection future, whose
        outcome is handled by :meth:`_connection_done`. This is deliberately a
        plain method rather than a coroutine: it never waits on anything, and
        errors while building the request now propagate to the caller instead
        of vanishing into an unobserved future.
        """
        self.kernel_id = kernel_id
        ws_url = url_path_join(
            GatewayClient.instance().ws_url,
            GatewayClient.instance().kernels_endpoint, url_escape(kernel_id), 'channels'
        )
        self.log.info('Connecting to %s', ws_url)
        kwargs = GatewayClient.instance().load_connection_args()
        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = websocket_connect(request)
        self.ws_future.add_done_callback(self._connection_done)

    def _connection_done(self, fut):
        """Done-callback for the connection future created by :meth:`_connect`.

        On success, stores the connection in ``self.ws``. On failure, logs the
        error and leaves ``self.ws`` as None (the read/write paths check for
        that). If the client already disconnected, any connection that
        completed anyway is closed.
        """
        if not self.ws_future_cancelled:  # prevent concurrent.futures._base.CancelledError
            try:
                self.ws = fut.result()
            except Exception as e:
                self.log.error("Exception connecting to gateway websocket for kernel '%s': %s",
                               self.kernel_id, e)
                return
            self.log.debug("Connection is ready: ws: %s", self.ws)
        else:
            self.log.warning("Websocket connection has been cancelled via client disconnect before its establishment. "
                             "Kernel with ID '{}' may not be terminated on GatewayClient: {}".
                             format(self.kernel_id, GatewayClient.instance().url))
            # cancel() may be a no-op, so the handshake can still succeed after
            # the client went away; close that connection rather than leak it.
            if not fut.cancelled() and fut.exception() is None:
                fut.result().close()

    def _disconnect(self):
        """Close the gateway connection, or abandon it if still connecting."""
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif not self.ws_future.done():
            # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
            self.ws_future.cancel()
            self.ws_future_cancelled = True
            self.log.debug("_disconnect: ws_future_cancelled: %s", self.ws_future_cancelled)

    @gen.coroutine
    def _read_messages(self, callback):
        """Read messages from gateway server and pass each one to ``callback``.

        Stops when the pending connection was cancelled, when no connection was
        established, when ``read_message`` yields None, or when reading raises.
        """
        while not self.ws_future_cancelled:  # ws cancelled - stop reading
            if self.ws is None:
                # Connection failed (reported by _connection_done): nothing to read.
                break
            message = None
            try:
                message = yield self.ws.read_message()
            except Exception as e:
                self.log.error("Exception reading message from websocket: {}".format(e))  # , exc_info=True)
            if message is None:
                break
            callback(message)  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server.

        Starts the connection and, once its future resolves, starts reading
        gateway messages into ``message_callback``. Extra keyword arguments
        (e.g. ``compression_options``) are accepted but not used here.
        """
        self._connect(kernel_id)
        loop = IOLoop.current()
        loop.add_future(
            self.ws_future,
            lambda future: self._read_messages(message_callback)
        )

    def on_message(self, message):
        """Send message to gateway server, deferring it until connected if needed."""
        if self.ws is None:
            loop = IOLoop.current()
            loop.add_future(
                self.ws_future,
                lambda future: self._write_message(message)
            )
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server.

        Silently drops the message if the pending connection was cancelled.
        Logs and drops it if no connection was established.
        """
        if self.ws_future_cancelled:
            return
        if self.ws is None:
            self.log.error("Gateway websocket connection is not established - message dropped")
            return
        try:
            self.ws.write_message(message)
        except Exception as e:
            self.log.error("Exception writing message to websocket: {}".format(e))  # , exc_info=True)

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
# Imported down here rather than with the other imports - presumably to avoid
# an import cycle with the kernels service handlers (TODO confirm).
from ..services.kernels.handlers import _kernel_id_regex

# Route table: the kernel channels websocket URL is served by the gateway proxy handler.
default_handlers = [
    (
        r"/api/kernels/%s/channels" % (_kernel_id_regex,),
        WebSocketChannelsHandler,
    ),
]