commit: 9a81be0d3715eb846d940794f8b34cbbe4ba67a5
parent: 5e2c5e95b6fceadc73c22d28f35d2c0f7c1f2601
Author: unarist <m.unarist@gmail.com>
Date: Tue, 30 May 2017 01:20:53 +0900
[RFC] Return 401 for an authentication error on WebSockets (#3411)
* Return 401 for an authentication error on WebSocket
* Use upgradeReq instead of a custom object
Diffstat:
M | streaming/index.js | 87 | ++++++++++++++++++++++++++++++++++++++++++++----------------------------------- |
1 file changed, 48 insertions(+), 39 deletions(-)
diff --git a/streaming/index.js b/streaming/index.js
@@ -95,7 +95,6 @@ const startWorker = (workerId) => {
const app = express();
const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL)));
const server = http.createServer(app);
- const wss = new WebSocket.Server({ server });
const redisNamespace = process.env.REDIS_NAMESPACE || null;
const redisParams = {
@@ -186,14 +185,10 @@ const startWorker = (workerId) => {
});
};
- const authenticationMiddleware = (req, res, next) => {
- if (req.method === 'OPTIONS') {
- next();
- return;
- }
-
- const authorization = req.get('Authorization');
- const accessToken = req.query.access_token;
+ const accountFromRequest = (req, next) => {
+ const authorization = req.headers.authorization;
+ const location = url.parse(req.url, true);
+ const accessToken = location.query.access_token;
if (!authorization && !accessToken) {
const err = new Error('Missing access token');
@@ -208,6 +203,26 @@ const startWorker = (workerId) => {
accountFromToken(token, req, next);
};
+ const wsVerifyClient = (info, cb) => {
+ accountFromRequest(info.req, err => {
+ if (!err) {
+ cb(true, undefined, undefined);
+ } else {
+ log.error(info.req.requestId, err.toString());
+ cb(false, 401, 'Unauthorized');
+ }
+ });
+ };
+
+ const authenticationMiddleware = (req, res, next) => {
+ if (req.method === 'OPTIONS') {
+ next();
+ return;
+ }
+
+ accountFromRequest(req, next);
+ };
+
const errorMiddleware = (err, req, res, next) => {
log.error(req.requestId, err.toString());
res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' });
@@ -352,10 +367,12 @@ const startWorker = (workerId) => {
streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true);
});
+ const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
+
wss.on('connection', ws => {
- const location = url.parse(ws.upgradeReq.url, true);
- const token = location.query.access_token;
- const req = { requestId: uuid.v4() };
+ const req = ws.upgradeReq;
+ const location = url.parse(req.url, true);
+ req.requestId = uuid.v4();
ws.isAlive = true;
@@ -363,33 +380,25 @@ const startWorker = (workerId) => {
ws.isAlive = true;
});
- accountFromToken(token, req, err => {
- if (err) {
- log.error(req.requestId, err);
- ws.close();
- return;
- }
-
- switch(location.query.stream) {
- case 'user':
- streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws));
- break;
- case 'public':
- streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
- break;
- case 'public:local':
- streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
- break;
- case 'hashtag':
- streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
- break;
- case 'hashtag:local':
- streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
- break;
- default:
- ws.close();
- }
- });
+ switch(location.query.stream) {
+ case 'user':
+ streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws));
+ break;
+ case 'public':
+ streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+ break;
+ case 'public:local':
+ streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+ break;
+ case 'hashtag':
+ streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+ break;
+ case 'hashtag:local':
+ streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+ break;
+ default:
+ ws.close();
+ }
});
const wsInterval = setInterval(() => {