@ -15,6 +15,8 @@ try:
except ImportError :
from Cookie import SimpleCookie # Py 2
import logging
import tornado
from tornado import web
from tornado import websocket
@ -26,29 +28,35 @@ from .handlers import IPythonHandler
class ZMQStreamHandler ( websocket . WebSocketHandler ) :
def same_origin ( self ) :
""" Check to see that origin and host match in the headers. """
# The difference between version 8 and 13 is that in 8 the
# client sends a "Sec-Websocket-Origin" header and in 13 it's
# simply "Origin".
if self . request . headers . get ( " Sec-WebSocket-Version " ) in ( " 7 " , " 8 " ) :
origin_header = self . request . headers . get ( " Sec-Websocket-Origin " )
else :
origin_header = self . request . headers . get ( " Origin " )
def check_origin ( self , origin ) :
""" Check Origin == Host or CORS origins. """
if self . cors_origin == ' * ' :
return True
host = self . request . headers . get ( " Host " )
# If no header is provided, assume we can't verify origin
if ( origin_header is None or host is None ) :
if ( origin is None or host is None ) :
return False
host_origin = " {0} :// {1} " . format ( self . request . protocol , host )
# OK if origin matches host
if origin == host_origin :
return True
# Check CORS headers
if self . cors_origin :
if self . cors_origin == ' * ' :
return True
else :
return self . cors_origin == origin
elif self . cors_origin_pat :
return bool ( self . cors_origin_pat . match ( origin ) )
else :
# No CORS headers, deny the request
return False
parsed_origin = urlparse ( origin_header )
origin = parsed_origin . netloc
# Check to see that origin matches host directly, including ports
return origin == host
def clear_cookie ( self , * args , * * kwargs ) :
""" meaningless for websockets """
@ -96,13 +104,21 @@ class ZMQStreamHandler(websocket.WebSocketHandler):
class AuthenticatedZMQStreamHandler ( ZMQStreamHandler , IPythonHandler ) :
def set_default_headers ( self ) :
""" Undo the set_default_headers in IPythonHandler
which doesn ' t make sense for websockets
"""
pass
def open ( self , kernel_id ) :
self . kernel_id = cast_unicode ( kernel_id , ' ascii ' )
# Check to see that origin matches host directly, including ports
if not self . same_origin ( ) :
self . log . warn ( " Cross Origin WebSocket Attempt. " )
raise web . HTTPError ( 404 )
# Tornado 4 already does CORS checking
if tornado . version_info [ 0 ] < 4 :
if not self . check_origin ( self . get_origin ( ) ) :
self . log . warn ( " Cross Origin WebSocket Attempt. " )
raise web . HTTPError ( 404 )
self . session = Session ( config = self . config )
self . save_on_message = self . on_message