Streaming: Rework websocket server initialisation & authentication code (#28631)

This commit is contained in:
Emelia Smith 2024-01-15 11:36:30 +01:00 committed by GitHub
parent e72676e83a
commit 58830be943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -182,14 +182,74 @@ const CHANNEL_NAMES = [
];
const startServer = async () => {
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
const server = http.createServer();
const wss = new WebSocket.Server({ noServer: true });
// Set the X-Request-Id header on WebSockets:
wss.on("headers", function onHeaders(headers, req) {
headers.push(`X-Request-Id: ${req.id}`);
});
const app = express();
app.set('trust proxy', process.env.TRUSTED_PROXY_IP ? process.env.TRUSTED_PROXY_IP.split(/(?:\s*,\s*|\s+)/) : 'loopback,uniquelocal');
const pgPool = new pg.Pool(pgConfigFromEnv(process.env));
const server = http.createServer(app);
app.use(cors());
// Handle eventsource & other http requests:
server.on('request', app);
// Handle upgrade requests:
server.on('upgrade', async function handleUpgrade(request, socket, head) {
/** @param {Error} err */
const onSocketError = (err) => {
log.error(`Error with websocket upgrade: ${err}`);
};
socket.on('error', onSocketError);
// Authenticate:
try {
await accountFromRequest(request);
} catch (err) {
log.error(`Error authenticating request: ${err}`);
// Unfortunately for using the on('upgrade') setup, we need to manually
// write a HTTP Response to the Socket to close the connection upgrade
// attempt, so the following code is to handle all of that.
const statusCode = err.status ?? 401;
/** @type {Record<string, string | number>} */
const headers = {
'Connection': 'close',
'Content-Type': 'text/plain',
'Content-Length': 0,
'X-Request-Id': request.id,
// TODO: Send the error message via header so it can be debugged in
// developer tools
};
// Ensure the socket is closed once we've finished writing to it:
socket.once('finish', () => {
socket.destroy();
});
// Write the HTTP response manually:
socket.end(`HTTP/1.1 ${statusCode} ${http.STATUS_CODES[statusCode]}\r\n${Object.keys(headers).map((key) => `${key}: ${headers[key]}`).join('\r\n')}\r\n\r\n`);
return;
}
wss.handleUpgrade(request, socket, head, function done(ws) {
// Remove the error handler:
socket.removeListener('error', onSocketError);
// Start the connection:
wss.emit('connection', ws, request);
});
});
/**
* @type {Object.<string, Array.<function(Object<string, any>): void>>}
*/
@ -360,10 +420,19 @@ const startServer = async () => {
const isInScope = (req, necessaryScopes) =>
req.scopes.some(scope => necessaryScopes.includes(scope));
/**
* @typedef ResolvedAccount
* @property {string} accessTokenId
* @property {string[]} scopes
* @property {string} accountId
* @property {string[]} chosenLanguages
* @property {string} deviceId
*/
/**
* @param {string} token
* @param {any} req
* @returns {Promise.<void>}
* @returns {Promise<ResolvedAccount>}
*/
const accountFromToken = (token, req) => new Promise((resolve, reject) => {
pgPool.connect((err, client, done) => {
@ -394,14 +463,20 @@ const startServer = async () => {
req.chosenLanguages = result.rows[0].chosen_languages;
req.deviceId = result.rows[0].device_id;
resolve();
resolve({
accessTokenId: result.rows[0].id,
scopes: result.rows[0].scopes.split(' '),
accountId: result.rows[0].account_id,
chosenLanguages: result.rows[0].chosen_languages,
deviceId: result.rows[0].device_id
});
});
});
});
/**
* @param {any} req
* @returns {Promise.<void>}
* @returns {Promise<ResolvedAccount>}
*/
const accountFromRequest = (req) => new Promise((resolve, reject) => {
const authorization = req.headers.authorization;
@ -494,25 +569,6 @@ const startServer = async () => {
reject(err);
});
/**
* @param {any} info
* @param {function(boolean, number, string): void} callback
*/
const wsVerifyClient = (info, callback) => {
// When verifying the websockets connection, we no longer pre-emptively
// check OAuth scopes and drop the connection if they're missing. We only
// drop the connection if access without token is not allowed by environment
// variables. OAuth scope checks are moved to the point of subscription
// to a specific stream.
accountFromRequest(info.req).then(() => {
callback(true, undefined, undefined);
}).catch(err => {
log.error(info.req.requestId, err.toString());
callback(false, 401, 'Unauthorized');
});
};
/**
* @typedef SystemMessageHandlers
* @property {function(): void} onKill
@ -944,8 +1000,8 @@ const startServer = async () => {
};
/**
* @param {any} req
* @param {any} ws
* @param {http.IncomingMessage} req
* @param {WebSocket} ws
* @param {string[]} streamName
* @returns {function(string, string): void}
*/
@ -955,7 +1011,9 @@ const startServer = async () => {
return;
}
ws.send(JSON.stringify({ stream: streamName, event, payload }), (err) => {
const message = JSON.stringify({ stream: streamName, event, payload });
ws.send(message, (/** @type {Error} */ err) => {
if (err) {
log.error(req.requestId, `Failed to send to websocket: ${err}`);
}
@ -992,8 +1050,6 @@ const startServer = async () => {
});
});
const wss = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
/**
* @typedef StreamParams
* @property {string} [tag]
@ -1173,8 +1229,8 @@ const startServer = async () => {
/**
* @typedef WebSocketSession
* @property {any} socket
* @property {any} request
* @property {WebSocket} websocket
* @property {http.IncomingMessage} request
* @property {Object.<string, { channelName: string, listener: SubscriptionListener, stopHeartbeat: function(): void }>} subscriptions
*/
@ -1297,7 +1353,11 @@ const startServer = async () => {
}
};
wss.on('connection', (ws, req) => {
/**
* @param {WebSocket & { isAlive: boolean }} ws
* @param {http.IncomingMessage} req
*/
function onConnection(ws, req) {
// Note: url.parse could throw, which would terminate the connection, so we
// increment the connected clients metric straight away when we establish
// the connection, without waiting:
@ -1375,7 +1435,9 @@ const startServer = async () => {
if (location && location.query.stream) {
subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
}
});
}
wss.on('connection', onConnection);
setInterval(() => {
wss.clients.forEach(ws => {