diff --git a/notebook/base/zmqhandlers.py b/notebook/base/zmqhandlers.py index 6a7b5bdcb..1cfa8ecb3 100644 --- a/notebook/base/zmqhandlers.py +++ b/notebook/base/zmqhandlers.py @@ -218,16 +218,23 @@ class ZMQStreamHandler(WebSocketMixin, WebSocketHandler): self.stream.close() - def _reserialize_reply(self, msg_list, channel=None): + def _reserialize_reply(self, msg_or_list, channel=None): """Reserialize a reply message using JSON. - This takes the msg list from the ZMQ socket, deserializes it using - self.session and then serializes the result using JSON. This method - should be used by self._on_zmq_reply to build messages that can + msg_or_list can be an already-deserialized msg dict or the zmq buffer list. + If it is the zmq list, it will be deserialized with self.session. + + This takes the msg list from the ZMQ socket and serializes the result for the websocket. + This method should be used by self._on_zmq_reply to build messages that can be sent back to the browser. + """ - idents, msg_list = self.session.feed_identities(msg_list) - msg = self.session.deserialize(msg_list) + if isinstance(msg_or_list, dict): + # already unpacked + msg = msg_or_list + else: + idents, msg_list = self.session.feed_identities(msg_or_list) + msg = self.session.deserialize(msg_list) if channel: msg['channel'] = channel if msg['buffers']: diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py index c68700c6f..ef5a3a683 100644 --- a/notebook/notebookapp.py +++ b/notebook/notebookapp.py @@ -75,7 +75,7 @@ from jupyter_client.session import Session from nbformat.sign import NotebookNotary from traitlets import ( Dict, Unicode, Integer, List, Bool, Bytes, Instance, - TraitError, Type, + TraitError, Type, Float ) from ipython_genutils import py3compat from jupyter_core.paths import jupyter_runtime_dir, jupyter_path @@ -185,7 +185,12 @@ class NotebookWebApplication(web.Application): }, version_hash=version_hash, ignore_minified_js=ipython_app.ignore_minified_js, - + + # rate limits + iopub_msg_rate_limit=ipython_app.iopub_msg_rate_limit, + iopub_data_rate_limit=ipython_app.iopub_data_rate_limit, + rate_limit_window=ipython_app.rate_limit_window, + # authentication cookie_secret=ipython_app.cookie_secret, login_url=url_path_join(base_url,'/login'), @@ -788,9 +793,20 @@ class NotebookApp(JupyterApp): help="Reraise exceptions encountered loading server extensions?", ) + iopub_msg_rate_limit = Float(0, config=True, help="""(msg/sec) + Maximum rate at which messages can be sent on iopub before they are + limited.""") + + iopub_data_rate_limit = Float(0, config=True, help="""(bytes/sec) + Maximum rate at which messages can be sent on iopub before they are + limited.""") + + rate_limit_window = Float(1.0, config=True, help="""(sec) Time window used to + check the message and data rate limits.""") + def parse_command_line(self, argv=None): super(NotebookApp, self).parse_command_line(argv) - + if self.extra_args: arg0 = self.extra_args[0] f = os.path.abspath(arg0) diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index 0da36714f..aa5b31979 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -96,14 +96,26 @@ class KernelActionHandler(APIHandler): class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): - + @property def kernel_info_timeout(self): return self.settings.get('kernel_info_timeout', 10) - + + @property + def iopub_msg_rate_limit(self): + return self.settings.get('iopub_msg_rate_limit', None) + + @property + def iopub_data_rate_limit(self): + return self.settings.get('iopub_data_rate_limit', None) + + @property + def rate_limit_window(self): + return self.settings.get('rate_limit_window', 1.0) + def __repr__(self): return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized')) - + def create_stream(self): km = self.kernel_manager identity = self.session.bsession @@ -182,7 +194,17 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): self.kernel_id = None self.kernel_info_channel = None self._kernel_info_future = Future() - + + # Rate limiting code + self._iopub_window_msg_count = 0 + self._iopub_window_byte_count = 0 + self._iopub_msgs_exceeded = False + self._iopub_data_exceeded = False + # Queue of (time stamp, byte count) + # Allows you to specify that the byte count should be lowered + # by a delta amount at some point in the future. + self._iopub_window_byte_queue = [] + @gen.coroutine def pre_get(self): # authenticate first @@ -244,6 +266,88 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): return stream = self.channels[channel] self.session.send(stream, msg) + + def _on_zmq_reply(self, stream, msg_list): + idents, fed_msg_list = self.session.feed_identities(msg_list) + msg = self.session.deserialize(fed_msg_list) + parent = msg['parent_header'] + def write_stderr(error_message): + self.log.warn(error_message) + msg = self.session.msg("stream", + content={"text": error_message, "name": "stderr"}, + parent=parent + ) + msg['channel'] = 'iopub' + self.write_message(json.dumps(msg, default=date_default)) + + channel = getattr(stream, 'channel', None) + msg_type = msg['header']['msg_type'] + if channel == 'iopub' and msg_type not in {'status', 'comm_open', 'execute_input'}: + + # Remove the counts queued for removal. + now = IOLoop.current().time() + while len(self._iopub_window_byte_queue) > 0: + queued = self._iopub_window_byte_queue[0] + if (now >= queued[0]): + self._iopub_window_byte_count -= queued[1] + self._iopub_window_msg_count -= 1 + del self._iopub_window_byte_queue[0] + else: + # This part of the queue hasn't be reached yet, so we can + # abort the loop. + break + + # Increment the bytes and message count + self._iopub_window_msg_count += 1 + byte_count = sum([len(x) for x in msg_list]) + self._iopub_window_byte_count += byte_count + + # Queue a removal of the byte and message count for a time in the + # future, when we are no longer interested in it. + self._iopub_window_byte_queue.append((now + self.rate_limit_window, byte_count)) + + # Check the limits, set the limit flags, and reset the + # message and data counts. + msg_rate = float(self._iopub_window_msg_count) / self.rate_limit_window + data_rate = float(self._iopub_window_byte_count) / self.rate_limit_window + + # Check the msg rate + if self.iopub_msg_rate_limit is not None and msg_rate > self.iopub_msg_rate_limit and self.iopub_msg_rate_limit > 0: + if not self._iopub_msgs_exceeded: + self._iopub_msgs_exceeded = True + write_stderr("""iopub message rate exceeded. The + notebook server will temporarily stop sending iopub + messages to the client in order to avoid crashing it. + To change this limit, set the config variable + `--NotebookApp.iopub_msg_rate_limit`.""") + return + else: + if self._iopub_msgs_exceeded: + self._iopub_msgs_exceeded = False + if not self._iopub_data_exceeded: + self.log.warn("iopub messages resumed") + + # Check the data rate + if self.iopub_data_rate_limit is not None and data_rate > self.iopub_data_rate_limit and self.iopub_data_rate_limit > 0: + if not self._iopub_data_exceeded: + self._iopub_data_exceeded = True + write_stderr("""iopub data rate exceeded. The + notebook server will temporarily stop sending iopub + messages to the client in order to avoid crashing it. + To change this limit, set the config variable + `--NotebookApp.iopub_data_rate_limit`.""") + return + else: + if self._iopub_data_exceeded: + self._iopub_data_exceeded = False + if not self._iopub_msgs_exceeded: + self.log.warn("iopub messages resumed") + + # If either of the limit flags are set, do not send the message. + if self._iopub_msgs_exceeded or self._iopub_data_exceeded: + return + super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg) + def on_close(self): km = self.kernel_manager