diff --git a/server/plugins/cot.js b/server/plugins/cot.js index 99ad83d..ef5eec6 100644 --- a/server/plugins/cot.js +++ b/server/plugins/cot.js @@ -9,8 +9,10 @@ import { registerCleanup } from '../utils/shutdown.js' import { COT_AUTH_TIMEOUT_MS } from '../utils/constants.js' import { acquire } from '../utils/asyncLock.js' -let tcpServer = null -let tlsServer = null +const serverState = { + tcpServer: null, + tlsServer: null, +} const relaySet = new Set() const allSockets = new Set() const socketBuffers = new WeakMap() @@ -44,22 +46,27 @@ function broadcast(senderSocket, rawMessage) { } } +const createPreview = (payload) => { + try { + const str = payload.toString('utf8') + if (str.startsWith('<')) { + const s = str.length <= 120 ? str : str.slice(0, 120) + '...' + // eslint-disable-next-line no-control-regex -- sanitize control chars for log preview + return s.replace(/[\u0000-\u0008\v\f\u000E-\u001F]/g, '.') + } + return 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex') + } + catch { + return 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex') + } +} + async function processFrame(socket, rawMessage, payload, authenticated) { const requireAuth = socket._cotRequireAuth !== false const debug = socket._cotDebug === true const parsed = parseCotPayload(payload) if (debug) { - let preview = payload.length - try { - const str = payload.toString('utf8') - if (str.startsWith('<')) { - const s = str.length <= 120 ? str : str.slice(0, 120) + '...' - // eslint-disable-next-line no-control-regex -- sanitize control chars for log preview - preview = s.replace(/[\u0000-\u0008\v\f\u000E-\u001F]/g, '.') - } - else preview = 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex') - } - catch { preview = 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex') } + const preview = createPreview(payload) console.log('[cot] payload length:', payload.length, 'parsed:', parsed ? parsed.type : null, 'preview:', preview) } if (!parsed) return @@ -108,10 +115,35 @@ async function processFrame(socket, rawMessage, payload, authenticated) { } } +const parseFrame = (buf) => { + const takResult = parseTakStreamFrame(buf) + if (takResult) return { result: takResult, frameType: 'tak' } + if (buf[0] === 0x3C) { + const xmlResult = parseTraditionalXmlFrame(buf) + if (xmlResult) return { result: xmlResult, frameType: 'traditional' } + } + return { result: null, frameType: null } +} + +const processBufferedData = async (socket, buf, authenticated) => { + if (buf.length === 0) return buf + const { result, frameType } = parseFrame(buf) + if (result && socket._cotDebug) { + console.log('[cot] frame parsed as', frameType, 'bytesConsumed=', result.bytesConsumed) + } + if (!result) return buf + const { payload, bytesConsumed } = result + const rawMessage = buf.subarray(0, bytesConsumed) + await processFrame(socket, rawMessage, payload, authenticated) + if (socket.destroyed) return null + const remainingBuf = buf.subarray(bytesConsumed) + socketBuffers.set(socket, remainingBuf) + return processBufferedData(socket, remainingBuf, authenticated) +} + async function onData(socket, data) { - let buf = socketBuffers.get(socket) - if (!buf) buf = Buffer.alloc(0) - buf = Buffer.concat([buf, data]) + const existingBuf = socketBuffers.get(socket) + const buf = Buffer.concat([existingBuf || Buffer.alloc(0), data]) socketBuffers.set(socket, buf) const authenticated = Boolean(socket._cotAuthenticated) @@ -120,22 +152,7 @@ async function onData(socket, data) { const hex = buf.subarray(0, Math.min(80, buf.length)).toString('hex') console.log('[cot] first chunk len=', buf.length, 'first bytes (hex):', hex, 'starts with 0xBF:', buf[0] === 0xBF, 'starts with <:', buf[0] === 0x3C) } - while (buf.length > 0) { - let result = parseTakStreamFrame(buf) - let frameType = 'tak' - if (!result && buf[0] === 0x3C) { - result = parseTraditionalXmlFrame(buf) - frameType = 'traditional' - } - if (result && socket._cotDebug) console.log('[cot] frame parsed as', frameType, 'bytesConsumed=', result.bytesConsumed) - if (!result) break - const { payload, bytesConsumed } = result - const rawMessage = buf.subarray(0, bytesConsumed) - await processFrame(socket, rawMessage, payload, authenticated) - if (socket.destroyed) return - buf = buf.subarray(bytesConsumed) - socketBuffers.set(socket, buf) - } + await processBufferedData(socket, buf, authenticated) } function setupSocket(socket, tls = false) { @@ -182,16 +199,16 @@ function startCotServers() { key: readFileSync(keyPath), rejectUnauthorized: false, } - tlsServer = createTlsServer(tlsOpts, socket => setupSocket(socket, true)) - tlsServer.on('error', err => console.error('[cot] TLS server error:', err?.message)) - tlsServer.listen(port, '0.0.0.0', () => { + serverState.tlsServer = createTlsServer(tlsOpts, socket => setupSocket(socket, true)) + serverState.tlsServer.on('error', err => console.error('[cot] TLS server error:', err?.message)) + serverState.tlsServer.listen(port, '0.0.0.0', () => { console.log('[cot] CoT server listening on 0.0.0.0:' + port + ' (TLS) — use this port in ATAK/iTAK and enable SSL') }) } else { - tcpServer = createTcpServer(socket => setupSocket(socket, false)) - tcpServer.on('error', err => console.error('[cot] TCP server error:', err?.message)) - tcpServer.listen(port, '0.0.0.0', () => { + serverState.tcpServer = createTcpServer(socket => setupSocket(socket, false)) + serverState.tcpServer.on('error', err => console.error('[cot] TCP server error:', err?.message)) + serverState.tcpServer.listen(port, '0.0.0.0', () => { console.log('[cot] CoT server listening on 0.0.0.0:' + port + ' (plain TCP) — use this port in ATAK/iTAK with SSL disabled') }) } @@ -209,7 +226,18 @@ export default defineNitroPlugin((nitroApp) => { // Start immediately so CoT is up before first request in dev; ready may fire late in some setups. setImmediate(startCotServers) - registerCleanup(async () => { + const cleanupServers = () => { + if (serverState.tcpServer) { + serverState.tcpServer.close() + serverState.tcpServer = null + } + if (serverState.tlsServer) { + serverState.tlsServer.close() + serverState.tlsServer = null + } + } + + const cleanupSockets = () => { for (const s of allSockets) { try { s.destroy() @@ -220,34 +248,15 @@ export default defineNitroPlugin((nitroApp) => { } allSockets.clear() relaySet.clear() - if (tcpServer) { - tcpServer.close() - tcpServer = null - } - if (tlsServer) { - tlsServer.close() - tlsServer = null - } + } + + registerCleanup(async () => { + cleanupSockets() + cleanupServers() }) nitroApp.hooks.hook('close', async () => { - for (const s of allSockets) { - try { - s.destroy() - } - catch { - /* ignore */ - } - } - allSockets.clear() - relaySet.clear() - if (tcpServer) { - tcpServer.close() - tcpServer = null - } - if (tlsServer) { - tlsServer.close() - tlsServer = null - } + cleanupSockets() + cleanupServers() }) }) diff --git a/server/utils/asyncLock.js b/server/utils/asyncLock.js index 3ea130d..32eacf5 100644 --- a/server/utils/asyncLock.js +++ b/server/utils/asyncLock.js @@ -5,6 +5,19 @@ const locks = new Map() +/** + * Get or create a queue for a lock key. + * @param {string} lockKey - Lock key + * @returns {Promise} Existing or new queue promise + */ +const getOrCreateQueue = (lockKey) => { + const existingQueue = locks.get(lockKey) + if (existingQueue) return existingQueue + const newQueue = Promise.resolve() + locks.set(lockKey, newQueue) + return newQueue +} + /** * Acquire a lock for a key and execute callback. * Only one callback per key executes at a time. @@ -14,12 +27,7 @@ const locks = new Map() */ export async function acquire(key, callback) { const lockKey = String(key) - let queue = locks.get(lockKey) - - if (!queue) { - queue = Promise.resolve() - locks.set(lockKey, queue) - } + const queue = getOrCreateQueue(lockKey) const next = queue.then(() => callback()).finally(() => { if (locks.get(lockKey) === next) { diff --git a/server/utils/cotParser.js b/server/utils/cotParser.js index e034794..49b4530 100644 --- a/server/utils/cotParser.js +++ b/server/utils/cotParser.js @@ -15,21 +15,20 @@ const TRADITIONAL_DELIMITER = Buffer.from('', 'utf8') /** * @param {Buffer} buf * @param {number} offset + * @param {number} value - Accumulated value + * @param {number} shift - Current bit shift + * @param {number} bytesRead - Bytes consumed so far * @returns {{ value: number, bytesRead: number }} Decoded varint and bytes consumed. */ -function readVarint(buf, offset) { - let value = 0 - let shift = 0 - let bytesRead = 0 - while (offset + bytesRead < buf.length) { - const b = buf[offset + bytesRead] - bytesRead += 1 - value += (b & 0x7F) << shift - if ((b & 0x80) === 0) return { value, bytesRead } - shift += 7 - if (shift > 28) return { value: 0, bytesRead: 0 } - } - return { value, bytesRead } +function readVarint(buf, offset, value = 0, shift = 0, bytesRead = 0) { + if (offset + bytesRead >= buf.length) return { value, bytesRead } + const b = buf[offset + bytesRead] + const newValue = value + ((b & 0x7F) << shift) + const newBytesRead = bytesRead + 1 + if ((b & 0x80) === 0) return { value: newValue, bytesRead: newBytesRead } + const newShift = shift + 7 + if (newShift > 28) return { value: 0, bytesRead: 0 } + return readVarint(buf, offset, newValue, newShift, newBytesRead) } /** @@ -127,12 +126,14 @@ export function parseCotPayload(payload) { const uid = String(event['@_uid'] ?? event.uid ?? '') const eventType = String(event['@_type'] ?? event.type ?? '') const point = findInObject(parsed, 'point') ?? findInObject(event, 'point') - let lat = Number.NaN - let lng = Number.NaN - if (point && typeof point === 'object') { - lat = Number(point['@_lat'] ?? point.lat) - lng = Number(point['@_lon'] ?? point.lon ?? point['@_lng'] ?? point.lng) + const extractCoords = (pt) => { + if (!pt || typeof pt !== 'object') return { lat: Number.NaN, lng: Number.NaN } + return { + lat: Number(pt['@_lat'] ?? pt.lat), + lng: Number(pt['@_lon'] ?? pt.lon ?? pt['@_lng'] ?? pt.lng), + } } + const { lat, lng } = extractCoords(point) if (!Number.isFinite(lat) || !Number.isFinite(lng)) return null const detail = findInObject(parsed, 'detail') diff --git a/server/utils/shutdown.js b/server/utils/shutdown.js index 9ea7e4e..610695b 100644 --- a/server/utils/shutdown.js +++ b/server/utils/shutdown.js @@ -5,11 +5,13 @@ import { SHUTDOWN_TIMEOUT_MS } from './constants.js' const cleanupFunctions = [] -let isShuttingDown = false +const shutdownState = { + isShuttingDown: false, +} export function clearCleanup() { cleanupFunctions.length = 0 - isShuttingDown = false + shutdownState.isShuttingDown = false } export function registerCleanup(fn) { @@ -17,17 +19,25 @@ export function registerCleanup(fn) { cleanupFunctions.push(fn) } -async function executeCleanup() { - if (isShuttingDown) return - isShuttingDown = true - for (let i = cleanupFunctions.length - 1; i >= 0; i--) { - try { - await cleanupFunctions[i]() - } - catch (error) { - console.error(`[shutdown] Cleanup function ${i} failed:`, error?.message || String(error)) - } +const executeCleanupFunction = async (fn, index) => { + try { + await fn() } + catch (error) { + console.error(`[shutdown] Cleanup function ${index} failed:`, error?.message || String(error)) + } +} + +const executeCleanupReverse = async (functions, index = functions.length - 1) => { + if (index < 0) return + await executeCleanupFunction(functions[index], index) + return executeCleanupReverse(functions, index - 1) +} + +async function executeCleanup() { + if (shutdownState.isShuttingDown) return + shutdownState.isShuttingDown = true + await executeCleanupReverse(cleanupFunctions) } export async function graceful(error) { diff --git a/test/helpers/fakeAtakClient.js b/test/helpers/fakeAtakClient.js index 808e2b8..2807844 100644 --- a/test/helpers/fakeAtakClient.js +++ b/test/helpers/fakeAtakClient.js @@ -1,3 +1,18 @@ +/** + * Encode a number as varint bytes (little-endian, continuation bit). + * @param {number} value - Value to encode + * @param {number[]} bytes - Accumulated bytes (default empty) + * @returns {number[]} Varint bytes + */ +const encodeVarint = (value, bytes = []) => { + const byte = value & 0x7F + const remaining = value >>> 7 + if (remaining === 0) { + return [...bytes, byte] + } + return encodeVarint(remaining, [...bytes, byte | 0x80]) +} + /** * Build a TAK Protocol stream frame: 0xBF, varint payload length, payload. * @param {string|Buffer} payload - UTF-8 payload (e.g. CoT XML) @@ -5,17 +20,7 @@ */ export function buildTakFrame(payload) { const buf = Buffer.isBuffer(payload) ? payload : Buffer.from(payload, 'utf8') - let n = buf.length - const varint = [] - while (true) { - const byte = n & 0x7F - n >>>= 7 - if (n === 0) { - varint.push(byte) - break - } - varint.push(byte | 0x80) - } + const varint = encodeVarint(buf.length) return Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint), buf]) } diff --git a/test/unit/asyncLock.spec.js b/test/unit/asyncLock.spec.js index 353924f..064f1d6 100644 --- a/test/unit/asyncLock.spec.js +++ b/test/unit/asyncLock.spec.js @@ -7,12 +7,12 @@ describe('asyncLock', () => { }) it('executes callback immediately when no lock exists', async () => { - let executed = false + const executed = { value: false } await acquire('test', async () => { - executed = true + executed.value = true return 42 }) - expect(executed).toBe(true) + expect(executed.value).toBe(true) }) it('returns callback result', async () => { @@ -24,43 +24,35 @@ describe('asyncLock', () => { it('serializes concurrent operations on same key', async () => { const results = [] - const promises = [] - - for (let i = 0; i < 5; i++) { - promises.push( - acquire('same-key', async () => { - results.push(`start-${i}`) - await new Promise(resolve => setTimeout(resolve, 10)) - results.push(`end-${i}`) - return i - }), - ) - } + const promises = Array.from({ length: 5 }, (_, i) => + acquire('same-key', async () => { + results.push(`start-${i}`) + await new Promise(resolve => setTimeout(resolve, 10)) + results.push(`end-${i}`) + return i + }), + ) await Promise.all(promises) // Operations should be serialized: start-end pairs should not interleave expect(results.length).toBe(10) - for (let i = 0; i < 5; i++) { + Array.from({ length: 5 }, (_, i) => { expect(results[i * 2]).toBe(`start-${i}`) expect(results[i * 2 + 1]).toBe(`end-${i}`) - } + }) }) it('allows parallel operations on different keys', async () => { const results = [] - const promises = [] - - for (let i = 0; i < 5; i++) { - promises.push( - acquire(`key-${i}`, async () => { - results.push(`start-${i}`) - await new Promise(resolve => setTimeout(resolve, 10)) - results.push(`end-${i}`) - return i - }), - ) - } + const promises = Array.from({ length: 5 }, (_, i) => + acquire(`key-${i}`, async () => { + results.push(`start-${i}`) + await new Promise(resolve => setTimeout(resolve, 10)) + results.push(`end-${i}`) + return i + }), + ) await Promise.all(promises) @@ -74,10 +66,10 @@ describe('asyncLock', () => { }) it('handles errors and releases lock', async () => { - let callCount = 0 + const callCount = { value: 0 } try { await acquire('error-key', async () => { - callCount++ + callCount.value++ throw new Error('Test error') }) } @@ -87,27 +79,22 @@ describe('asyncLock', () => { // Lock should be released, next operation should execute await acquire('error-key', async () => { - callCount++ + callCount.value++ return 'success' }) - expect(callCount).toBe(2) + expect(callCount.value).toBe(2) }) it('maintains lock ordering', async () => { const order = [] - const promises = [] - - for (let i = 0; i < 3; i++) { - const idx = i - promises.push( - acquire('ordered', async () => { - order.push(`before-${idx}`) - await new Promise(resolve => setTimeout(resolve, 5)) - order.push(`after-${idx}`) - }), - ) - } + const promises = Array.from({ length: 3 }, (_, i) => + acquire('ordered', async () => { + order.push(`before-${i}`) + await new Promise(resolve => setTimeout(resolve, 5)) + order.push(`after-${i}`) + }), + ) await Promise.all(promises) diff --git a/test/unit/cotParser.spec.js b/test/unit/cotParser.spec.js index 1e51c5a..3d8bc2d 100644 --- a/test/unit/cotParser.spec.js +++ b/test/unit/cotParser.spec.js @@ -1,19 +1,18 @@ import { describe, it, expect } from 'vitest' import { parseTakStreamFrame, parseTraditionalXmlFrame, parseCotPayload } from '../../../server/utils/cotParser.js' +const encodeVarint = (value, bytes = []) => { + const byte = value & 0x7F + const remaining = value >>> 7 + if (remaining === 0) { + return [...bytes, byte] + } + return encodeVarint(remaining, [...bytes, byte | 0x80]) +} + function buildTakFrame(payload) { const buf = Buffer.from(payload, 'utf8') - let n = buf.length - const varint = [] - while (true) { - const byte = n & 0x7F - n >>>= 7 - if (n === 0) { - varint.push(byte) - break - } - varint.push(byte | 0x80) - } + const varint = encodeVarint(buf.length) return Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint), buf]) } @@ -42,14 +41,7 @@ describe('cotParser', () => { it('returns null for payload length exceeding max', () => { const hugeLen = 64 * 1024 + 1 - const varint = [] - let n = hugeLen - while (true) { - varint.push(n & 0x7F) - n >>>= 7 - if (n === 0) break - varint[varint.length - 1] |= 0x80 - } + const varint = encodeVarint(hugeLen) const buf = Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint)]) expect(parseTakStreamFrame(buf)).toBeNull() }) diff --git a/test/unit/cotSsl.spec.js b/test/unit/cotSsl.spec.js index 3e45097..e8ceef0 100644 --- a/test/unit/cotSsl.spec.js +++ b/test/unit/cotSsl.spec.js @@ -13,23 +13,25 @@ import { import { withTemporaryEnv } from '../helpers/env.js' describe('cotSsl', () => { - let testCertDir - let testCertPath - let testKeyPath + const testPaths = { + testCertDir: null, + testCertPath: null, + testKeyPath: null, + } beforeEach(() => { - testCertDir = join(tmpdir(), `kestrelos-test-${Date.now()}`) - mkdirSync(testCertDir, { recursive: true }) - testCertPath = join(testCertDir, 'cert.pem') - testKeyPath = join(testCertDir, 'key.pem') - writeFileSync(testCertPath, '-----BEGIN CERTIFICATE-----\nTEST\n-----END CERTIFICATE-----\n') - writeFileSync(testKeyPath, '-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n') + testPaths.testCertDir = join(tmpdir(), `kestrelos-test-${Date.now()}`) + mkdirSync(testPaths.testCertDir, { recursive: true }) + testPaths.testCertPath = join(testPaths.testCertDir, 'cert.pem') + testPaths.testKeyPath = join(testPaths.testCertDir, 'key.pem') + writeFileSync(testPaths.testCertPath, '-----BEGIN CERTIFICATE-----\nTEST\n-----END CERTIFICATE-----\n') + writeFileSync(testPaths.testKeyPath, '-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n') }) afterEach(() => { try { - if (existsSync(testCertPath)) unlinkSync(testCertPath) - if (existsSync(testKeyPath)) unlinkSync(testKeyPath) + if (existsSync(testPaths.testCertPath)) unlinkSync(testPaths.testCertPath) + if (existsSync(testPaths.testKeyPath)) unlinkSync(testPaths.testKeyPath) } catch { // Ignore cleanup errors @@ -78,22 +80,22 @@ describe('cotSsl', () => { }) it('returns paths from COT_SSL_CERT and COT_SSL_KEY env vars', () => { - withTemporaryEnv({ COT_SSL_CERT: testCertPath, COT_SSL_KEY: testKeyPath }, () => { - expect(getCotSslPaths()).toEqual({ certPath: testCertPath, keyPath: testKeyPath }) + withTemporaryEnv({ COT_SSL_CERT: testPaths.testCertPath, COT_SSL_KEY: testPaths.testKeyPath }, () => { + expect(getCotSslPaths()).toEqual({ certPath: testPaths.testCertPath, keyPath: testPaths.testKeyPath }) }) }) it('returns paths from config parameter when env vars not set', () => { withTemporaryEnv({ COT_SSL_CERT: undefined, COT_SSL_KEY: undefined }, () => { - const config = { cotSslCert: testCertPath, cotSslKey: testKeyPath } - expect(getCotSslPaths(config)).toEqual({ certPath: testCertPath, keyPath: testKeyPath }) + const config = { cotSslCert: testPaths.testCertPath, cotSslKey: testPaths.testKeyPath } + expect(getCotSslPaths(config)).toEqual({ certPath: testPaths.testCertPath, keyPath: testPaths.testKeyPath }) }) }) it('prefers env vars over config parameter', () => { - withTemporaryEnv({ COT_SSL_CERT: testCertPath, COT_SSL_KEY: testKeyPath }, () => { + withTemporaryEnv({ COT_SSL_CERT: testPaths.testCertPath, COT_SSL_KEY: testPaths.testKeyPath }, () => { const config = { cotSslCert: '/other/cert.pem', cotSslKey: '/other/key.pem' } - expect(getCotSslPaths(config)).toEqual({ certPath: testCertPath, keyPath: testKeyPath }) + expect(getCotSslPaths(config)).toEqual({ certPath: testPaths.testCertPath, keyPath: testPaths.testKeyPath }) }) }) @@ -113,7 +115,7 @@ describe('cotSsl', () => { }) it('throws error when openssl command fails', () => { - const invalidCertPath = join(testCertDir, 'invalid.pem') + const invalidCertPath = join(testPaths.testCertDir, 'invalid.pem') writeFileSync(invalidCertPath, 'invalid cert content') expect(() => { buildP12FromCertPath(invalidCertPath, 'password') @@ -121,7 +123,7 @@ describe('cotSsl', () => { }) it('cleans up temp file on error', () => { - const invalidCertPath = join(testCertDir, 'invalid.pem') + const invalidCertPath = join(testPaths.testCertDir, 'invalid.pem') writeFileSync(invalidCertPath, 'invalid cert content') try { buildP12FromCertPath(invalidCertPath, 'password') diff --git a/test/unit/liveSessions.spec.js b/test/unit/liveSessions.spec.js index 2f669c0..fd65d43 100644 --- a/test/unit/liveSessions.spec.js +++ b/test/unit/liveSessions.spec.js @@ -17,18 +17,20 @@ vi.mock('../../../server/utils/mediasoup.js', () => ({ })) describe('liveSessions', () => { - let sessionId + const testState = { + sessionId: null, + } beforeEach(async () => { clearSessions() const session = await createSession('test-user', 'Test Session') - sessionId = session.id + testState.sessionId = session.id }) it('creates a session with WebRTC fields', () => { - const session = getLiveSession(sessionId) + const session = getLiveSession(testState.sessionId) expect(session).toBeDefined() - expect(session.id).toBe(sessionId) + expect(session.id).toBe(testState.sessionId) expect(session.userId).toBe('test-user') expect(session.label).toBe('Test Session') expect(session.routerId).toBeNull() @@ -37,45 +39,45 @@ describe('liveSessions', () => { }) it('updates location', async () => { - await updateLiveSession(sessionId, { lat: 37.7, lng: -122.4 }) - const session = getLiveSession(sessionId) + await updateLiveSession(testState.sessionId, { lat: 37.7, lng: -122.4 }) + const session = getLiveSession(testState.sessionId) expect(session.lat).toBe(37.7) expect(session.lng).toBe(-122.4) }) it('updates WebRTC fields', async () => { - await updateLiveSession(sessionId, { routerId: 'router-1', producerId: 'producer-1', transportId: 'transport-1' }) - const session = getLiveSession(sessionId) + await updateLiveSession(testState.sessionId, { routerId: 'router-1', producerId: 'producer-1', transportId: 'transport-1' }) + const session = getLiveSession(testState.sessionId) expect(session.routerId).toBe('router-1') expect(session.producerId).toBe('producer-1') expect(session.transportId).toBe('transport-1') }) it('returns hasStream instead of hasSnapshot', async () => { - await updateLiveSession(sessionId, { producerId: 'producer-1' }) + await updateLiveSession(testState.sessionId, { producerId: 'producer-1' }) const active = await getActiveSessions() - const session = active.find(s => s.id === sessionId) + const session = active.find(s => s.id === testState.sessionId) expect(session).toBeDefined() expect(session.hasStream).toBe(true) }) it('returns hasStream false when no producer', async () => { const active = await getActiveSessions() - const session = active.find(s => s.id === sessionId) + const session = active.find(s => s.id === testState.sessionId) expect(session).toBeDefined() expect(session.hasStream).toBe(false) }) it('deletes a session', async () => { - await deleteLiveSession(sessionId) - const session = getLiveSession(sessionId) + await deleteLiveSession(testState.sessionId) + const session = getLiveSession(testState.sessionId) expect(session).toBeUndefined() }) it('getActiveSessionByUserId returns session for same user when active', async () => { const found = await getActiveSessionByUserId('test-user') expect(found).toBeDefined() - expect(found.id).toBe(sessionId) + expect(found.id).toBe(testState.sessionId) }) it('getActiveSessionByUserId returns undefined for unknown user', async () => { @@ -84,18 +86,18 @@ describe('liveSessions', () => { }) it('getActiveSessionByUserId returns undefined for expired session', async () => { - const session = getLiveSession(sessionId) + const session = getLiveSession(testState.sessionId) session.updatedAt = Date.now() - 120_000 const found = await getActiveSessionByUserId('test-user') expect(found).toBeUndefined() }) it('getActiveSessions removes expired sessions', async () => { - const session = getLiveSession(sessionId) + const session = getLiveSession(testState.sessionId) session.updatedAt = Date.now() - 120_000 const active = await getActiveSessions() - expect(active.find(s => s.id === sessionId)).toBeUndefined() - expect(getLiveSession(sessionId)).toBeUndefined() + expect(active.find(s => s.id === testState.sessionId)).toBeUndefined() + expect(getLiveSession(testState.sessionId)).toBeUndefined() }) it('getActiveSessions runs cleanup for expired session with producer and transport', async () => { @@ -105,19 +107,19 @@ describe('liveSessions', () => { getProducer.mockReturnValue(mockProducer) getTransport.mockReturnValue(mockTransport) closeRouter.mockResolvedValue(undefined) - await updateLiveSession(sessionId, { producerId: 'p1', transportId: 't1', routerId: 'r1' }) - const session = getLiveSession(sessionId) + await updateLiveSession(testState.sessionId, { producerId: 'p1', transportId: 't1', routerId: 'r1' }) + const session = getLiveSession(testState.sessionId) session.updatedAt = Date.now() - 120_000 const active = await getActiveSessions() - expect(active.find(s => s.id === sessionId)).toBeUndefined() + expect(active.find(s => s.id === testState.sessionId)).toBeUndefined() expect(mockProducer.close).toHaveBeenCalled() expect(mockTransport.close).toHaveBeenCalled() - expect(closeRouter).toHaveBeenCalledWith(sessionId) + expect(closeRouter).toHaveBeenCalledWith(testState.sessionId) }) it('getOrCreateSession returns existing active session', async () => { const session = await getOrCreateSession('test-user', 'New Label') - expect(session.id).toBe(sessionId) + expect(session.id).toBe(testState.sessionId) expect(session.userId).toBe('test-user') }) @@ -128,10 +130,9 @@ describe('liveSessions', () => { }) it('getOrCreateSession handles concurrent calls atomically', async () => { - const promises = [] - for (let i = 0; i < 5; i++) { - promises.push(getOrCreateSession('concurrent-user', 'Concurrent')) - } + const promises = Array.from({ length: 5 }, () => + getOrCreateSession('concurrent-user', 'Concurrent'), + ) const sessions = await Promise.all(promises) const uniqueIds = new Set(sessions.map(s => s.id)) expect(uniqueIds.size).toBe(1) diff --git a/test/unit/logger.spec.js b/test/unit/logger.spec.js index b948b12..c051dc6 100644 --- a/test/unit/logger.spec.js +++ b/test/unit/logger.spec.js @@ -2,41 +2,43 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' import { info, error, warn, debug, setContext, clearContext, runWithContext } from '../../server/utils/logger.js' describe('logger', () => { - let originalLog - let originalError - let originalWarn - let originalDebug - let logCalls - let errorCalls - let warnCalls - let debugCalls + const testState = { + originalLog: null, + originalError: null, + originalWarn: null, + originalDebug: null, + logCalls: [], + errorCalls: [], + warnCalls: [], + debugCalls: [], + } beforeEach(() => { - logCalls = [] - errorCalls = [] - warnCalls = [] - debugCalls = [] - originalLog = console.log - originalError = console.error - originalWarn = console.warn - originalDebug = console.debug - console.log = vi.fn((...args) => logCalls.push(args)) - console.error = vi.fn((...args) => errorCalls.push(args)) - console.warn = vi.fn((...args) => warnCalls.push(args)) - console.debug = vi.fn((...args) => debugCalls.push(args)) + testState.logCalls = [] + testState.errorCalls = [] + testState.warnCalls = [] + testState.debugCalls = [] + testState.originalLog = console.log + testState.originalError = console.error + testState.originalWarn = console.warn + testState.originalDebug = console.debug + console.log = vi.fn((...args) => testState.logCalls.push(args)) + console.error = vi.fn((...args) => testState.errorCalls.push(args)) + console.warn = vi.fn((...args) => testState.warnCalls.push(args)) + console.debug = vi.fn((...args) => testState.debugCalls.push(args)) }) afterEach(() => { - console.log = originalLog - console.error = originalError - console.warn = originalWarn - console.debug = originalDebug + console.log = testState.originalLog + console.error = testState.originalError + console.warn = testState.originalWarn + console.debug = testState.originalDebug }) it('logs info message', () => { info('Test message') - expect(logCalls.length).toBe(1) - const logMsg = logCalls[0][0] + expect(testState.logCalls.length).toBe(1) + const logMsg = testState.logCalls[0][0] expect(logMsg).toContain('[INFO]') expect(logMsg).toContain('Test message') }) @@ -44,7 +46,7 @@ describe('logger', () => { it('includes request context when set', async () => { await runWithContext('req-123', 'user-456', async () => { info('Test message') - const logMsg = logCalls[0][0] + const logMsg = testState.logCalls[0][0] expect(logMsg).toContain('req-123') expect(logMsg).toContain('user-456') }) @@ -52,7 +54,7 @@ describe('logger', () => { it('includes additional context', () => { info('Test message', { key: 'value', count: 42 }) - const logMsg = logCalls[0][0] + const logMsg = testState.logCalls[0][0] expect(logMsg).toContain('key') expect(logMsg).toContain('value') expect(logMsg).toContain('42') @@ -61,8 +63,8 @@ describe('logger', () => { it('logs error with stack trace', () => { const err = new Error('Test error') error('Failed', { error: err }) - expect(errorCalls.length).toBe(1) - const errorMsg = errorCalls[0][0] + expect(testState.errorCalls.length).toBe(1) + const errorMsg = testState.errorCalls[0][0] expect(errorMsg).toContain('[ERROR]') expect(errorMsg).toContain('Failed') expect(errorMsg).toContain('stack') @@ -70,8 +72,8 @@ describe('logger', () => { it('logs warning', () => { warn('Warning message') - expect(warnCalls.length).toBe(1) - const warnMsg = warnCalls[0][0] + expect(testState.warnCalls.length).toBe(1) + const warnMsg = testState.warnCalls[0][0] expect(warnMsg).toContain('[WARN]') }) @@ -79,7 +81,7 @@ describe('logger', () => { const originalEnv = process.env.NODE_ENV process.env.NODE_ENV = 'development' debug('Debug message') - expect(debugCalls.length).toBe(1) + expect(testState.debugCalls.length).toBe(1) process.env.NODE_ENV = originalEnv }) @@ -87,19 +89,19 @@ describe('logger', () => { const originalEnv = process.env.NODE_ENV process.env.NODE_ENV = 'production' debug('Debug message') - expect(debugCalls.length).toBe(0) + expect(testState.debugCalls.length).toBe(0) process.env.NODE_ENV = originalEnv }) it('clears context', async () => { await runWithContext('req-123', 'user-456', async () => { info('Test with context') - const logMsg = logCalls[0][0] + const logMsg = testState.logCalls[0][0] expect(logMsg).toContain('req-123') }) // Context should be cleared after runWithContext completes info('Test without context') - const logMsg = logCalls[logCalls.length - 1][0] + const logMsg = testState.logCalls[testState.logCalls.length - 1][0] expect(logMsg).not.toContain('req-123') }) @@ -107,12 +109,12 @@ describe('logger', () => { await runWithContext(null, null, async () => { setContext('req-123', 'user-456') info('Test message') - const logMsg = logCalls[0][0] + const logMsg = testState.logCalls[0][0] expect(logMsg).toContain('req-123') expect(logMsg).toContain('user-456') clearContext() info('Test after clear') - const logMsg2 = logCalls[1][0] + const logMsg2 = testState.logCalls[1][0] expect(logMsg2).not.toContain('req-123') }) }) diff --git a/test/unit/mediasoup.spec.js b/test/unit/mediasoup.spec.js index 6906749..c051fb5 100644 --- a/test/unit/mediasoup.spec.js +++ b/test/unit/mediasoup.spec.js @@ -3,28 +3,30 @@ import { createSession, deleteLiveSession } from '../../../server/utils/liveSess import { getRouter, createTransport, closeRouter, getTransport, createProducer, getProducer, createConsumer } from '../../../server/utils/mediasoup.js' describe('Mediasoup', () => { - let sessionId + const testState = { + sessionId: null, + } beforeEach(() => { - sessionId = createSession('test-user', 'Test Session').id + testState.sessionId = createSession('test-user', 'Test Session').id }) afterEach(async () => { - if (sessionId) { - await closeRouter(sessionId) - deleteLiveSession(sessionId) + if (testState.sessionId) { + await closeRouter(testState.sessionId) + deleteLiveSession(testState.sessionId) } }) it('should create a router for a session', async () => { - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) expect(router).toBeDefined() expect(router.id).toBeDefined() expect(router.rtpCapabilities).toBeDefined() }) it('should create a transport', async () => { - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport, params } = await createTransport(router) expect(transport).toBeDefined() expect(params.id).toBe(transport.id) @@ -34,7 +36,7 @@ describe('Mediasoup', () => { }) it('should create a transport with requestHost IPv4 and return valid params', async () => { - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport, params } = await createTransport(router, '192.168.2.100') expect(transport).toBeDefined() expect(params.id).toBe(transport.id) @@ -45,13 +47,13 @@ describe('Mediasoup', () => { }) it('should reuse router for same session', async () => { - const router1 = await getRouter(sessionId) - const router2 = await getRouter(sessionId) + const router1 = await getRouter(testState.sessionId) + const router2 = await getRouter(testState.sessionId) expect(router1.id).toBe(router2.id) }) it('should get transport by ID', async () => { - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport } = await createTransport(router, true) const retrieved = getTransport(transport.id) expect(retrieved).toBe(transport) @@ -59,7 +61,7 @@ describe('Mediasoup', () => { it.skip('should create a producer with mock track', async () => { // Mediasoup produce() requires a real MediaStreamTrack (native addon); plain mocks fail with "invalid kind" - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport } = await createTransport(router, true) const mockTrack = { id: 'mock-track-id', @@ -77,24 +79,25 @@ describe('Mediasoup', () => { it.skip('should cleanup producer on close', async () => { // Depends on createProducer which requires real MediaStreamTrack in Node - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport } = await createTransport(router, true) const mockTrack = { id: 'mock-track-id', kind: 'video', enabled: true, readyState: 'live' } const producer = await createProducer(transport, mockTrack) const producerId = producer.id expect(getProducer(producerId)).toBe(producer) producer.close() - let attempts = 0 - while (getProducer(producerId) && attempts < 50) { + const waitForCleanup = async (maxAttempts = 50) => { + if (maxAttempts <= 0 || !getProducer(producerId)) return await new Promise(resolve => setTimeout(resolve, 10)) - attempts++ + return waitForCleanup(maxAttempts - 1) } + await waitForCleanup() expect(getProducer(producerId) || producer.closed).toBeTruthy() }) it.skip('should create a consumer', async () => { // Depends on createProducer which requires real MediaStreamTrack in Node - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport } = await createTransport(router, true) const mockTrack = { id: 'mock-track-id', kind: 'video', enabled: true, readyState: 'live' } const producer = await createProducer(transport, mockTrack) @@ -110,7 +113,7 @@ describe('Mediasoup', () => { }) it('should cleanup transport on close', async () => { - const router = await getRouter(sessionId) + const router = await getRouter(testState.sessionId) const { transport } = await createTransport(router, true) const transportId = transport.id expect(getTransport(transportId)).toBe(transport) @@ -118,19 +121,20 @@ describe('Mediasoup', () => { transport.close() // Wait for async cleanup (mediasoup fires 'close' event asynchronously) // Use a promise that resolves when transport is removed or timeout - let attempts = 0 - while (getTransport(transportId) && attempts < 50) { + const waitForCleanup = async (maxAttempts = 50) => { + if (maxAttempts <= 0 || !getTransport(transportId)) return await new Promise(resolve => setTimeout(resolve, 10)) - attempts++ + return waitForCleanup(maxAttempts - 1) } + await waitForCleanup() // Transport should be removed from Map (or at least closed) expect(getTransport(transportId) || transport.closed).toBeTruthy() }) it('should cleanup router on closeRouter', async () => { - await getRouter(sessionId) - await closeRouter(sessionId) - const routerAfter = await getRouter(sessionId) + await getRouter(testState.sessionId) + await closeRouter(testState.sessionId) + const routerAfter = await getRouter(testState.sessionId) // New router should have different ID (or same if cached, but old one should be closed) // This test verifies closeRouter doesn't throw expect(routerAfter).toBeDefined() diff --git a/test/unit/server-imports.spec.js b/test/unit/server-imports.spec.js index c99f41d..216388e 100644 --- a/test/unit/server-imports.spec.js +++ b/test/unit/server-imports.spec.js @@ -28,12 +28,23 @@ function getRelativeImports(content) { const paths = [] const fromRegex = /from\s+['"]([^'"]+)['"]/g const requireRegex = /require\s*\(\s*['"]([^'"]+)['"]\s*\)/g - for (const re of [fromRegex, requireRegex]) { - let m - while ((m = re.exec(content)) !== null) { - const p = m[1] - if (p.startsWith('.')) paths.push(p) + const extractMatches = (regex, text) => { + const matches = [] + const execRegex = (r) => { + const match = r.exec(text) + if (match) { + matches.push(match[1]) + return execRegex(r) + } + return matches } + return execRegex(regex) + } + for (const re of [fromRegex, requireRegex]) { + const matches = extractMatches(re, content) + matches.forEach((p) => { + if (p.startsWith('.')) paths.push(p) + }) } return paths } diff --git a/test/unit/shutdown.spec.js b/test/unit/shutdown.spec.js index c205841..f8ca0dc 100644 --- a/test/unit/shutdown.spec.js +++ b/test/unit/shutdown.spec.js @@ -2,20 +2,22 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' import { registerCleanup, graceful, clearCleanup, initShutdownHandlers } from '../../server/utils/shutdown.js' describe('shutdown', () => { - let originalExit - let exitCalls + const testState = { + originalExit: null, + exitCalls: [], + } beforeEach(() => { clearCleanup() - exitCalls = [] - originalExit = process.exit + testState.exitCalls = [] + testState.originalExit = process.exit process.exit = vi.fn((code) => { - exitCalls.push(code) + testState.exitCalls.push(code) }) }) afterEach(() => { - process.exit = originalExit + process.exit = testState.originalExit clearCleanup() }) @@ -46,7 +48,7 @@ describe('shutdown', () => { await graceful() expect(calls).toEqual(['third', 'second', 'first']) - expect(exitCalls).toEqual([0]) + expect(testState.exitCalls).toEqual([0]) }) it('handles cleanup function errors gracefully', async () => { @@ -59,26 +61,26 @@ describe('shutdown', () => { await graceful() - expect(exitCalls).toEqual([0]) + expect(testState.exitCalls).toEqual([0]) }) it('exits with code 1 on error', async () => { const error = new Error('Test error') await graceful(error) - expect(exitCalls).toEqual([1]) + expect(testState.exitCalls).toEqual([1]) }) it('prevents multiple shutdowns', async () => { - let callCount = 0 + const callCount = { value: 0 } registerCleanup(async () => { - callCount++ + callCount.value++ }) await graceful() await graceful() - expect(callCount).toBe(1) + expect(callCount.value).toBe(1) }) it('handles cleanup error during graceful shutdown', async () => { @@ -88,7 +90,7 @@ describe('shutdown', () => { await graceful() - expect(exitCalls).toEqual([0]) + expect(testState.exitCalls).toEqual([0]) }) it('handles error in executeCleanup catch block', async () => { @@ -98,20 +100,20 @@ describe('shutdown', () => { await graceful() - expect(exitCalls.length).toBeGreaterThan(0) + expect(testState.exitCalls.length).toBeGreaterThan(0) }) it('handles error with stack trace', async () => { const error = new Error('Test error') error.stack = 'Error: Test error\n at test.js:1:1' await graceful(error) - expect(exitCalls).toEqual([1]) + expect(testState.exitCalls).toEqual([1]) }) it('handles error without stack trace', async () => { const error = { message: 'Test error' } await graceful(error) - expect(exitCalls).toEqual([1]) + expect(testState.exitCalls).toEqual([1]) }) it('handles timeout scenario', async () => { @@ -119,7 +121,7 @@ describe('shutdown', () => { await new Promise(resolve => setTimeout(resolve, 40000)) }) const timeout = setTimeout(() => { - expect(exitCalls.length).toBeGreaterThan(0) + expect(testState.exitCalls.length).toBeGreaterThan(0) }, 35000) graceful() await new Promise(resolve => setTimeout(resolve, 100)) @@ -130,7 +132,7 @@ describe('shutdown', () => { registerCleanup(async () => {}) await graceful() await graceful() // Second call should return early - expect(exitCalls.length).toBeGreaterThan(0) + expect(testState.exitCalls.length).toBeGreaterThan(0) }) it('covers initShutdownHandlers', () => { @@ -151,7 +153,7 @@ describe('shutdown', () => { throw new Error('Force error in cleanup') }) await graceful() - expect(exitCalls.length).toBeGreaterThan(0) + expect(testState.exitCalls.length).toBeGreaterThan(0) }) it('covers graceful catch block when executeCleanup throws', async () => { @@ -176,7 +178,7 @@ describe('shutdown', () => { await graceful() // Should exit successfully (code 0) because executeCleanup handles errors internally - expect(exitCalls).toContain(0) + expect(testState.exitCalls).toContain(0) expect(clearTimeoutCalls.length).toBeGreaterThan(0) global.clearTimeout = originalClearTimeout }) @@ -213,7 +215,7 @@ describe('shutdown', () => { await new Promise(resolve => setTimeout(resolve, 10)) expect(errorLogs.some(log => log.includes('Error in graceful shutdown'))).toBe(true) - expect(exitCalls).toContain(1) + expect(testState.exitCalls).toContain(1) process.on = originalOn process.exit = originalExit diff --git a/test/unit/webrtcSignaling.spec.js b/test/unit/webrtcSignaling.spec.js index 2c754a2..f50316b 100644 --- a/test/unit/webrtcSignaling.spec.js +++ b/test/unit/webrtcSignaling.spec.js @@ -19,13 +19,15 @@ vi.mock('../../server/utils/mediasoup.js', () => { }) describe('webrtcSignaling', () => { - let sessionId + const testState = { + sessionId: null, + } const userId = 'test-user' beforeEach(async () => { clearSessions() const session = await createSession(userId, 'Test') - sessionId = session.id + testState.sessionId = session.id }) it('returns error when session not found', async () => { @@ -34,35 +36,35 @@ describe('webrtcSignaling', () => { }) it('returns Forbidden when userId does not match session', async () => { - const res = await handleWebSocketMessage('other-user', sessionId, 'create-transport', {}) + const res = await handleWebSocketMessage('other-user', testState.sessionId, 'create-transport', {}) expect(res).toEqual({ error: 'Forbidden' }) }) it('returns error for unknown message type', async () => { - const res = await handleWebSocketMessage(userId, sessionId, 'unknown-type', {}) + const res = await handleWebSocketMessage(userId, testState.sessionId, 'unknown-type', {}) expect(res).toEqual({ error: 'Unknown message type: unknown-type' }) }) it('returns transportId and dtlsParameters required for connect-transport', async () => { - const res = await handleWebSocketMessage(userId, sessionId, 'connect-transport', {}) + const res = await handleWebSocketMessage(userId, testState.sessionId, 'connect-transport', {}) expect(res?.error).toContain('transportId') }) it('get-router-rtp-capabilities returns router RTP capabilities', async () => { - const res = await handleWebSocketMessage(userId, sessionId, 'get-router-rtp-capabilities', {}) + const res = await handleWebSocketMessage(userId, testState.sessionId, 'get-router-rtp-capabilities', {}) expect(res?.type).toBe('router-rtp-capabilities') expect(res?.data).toEqual({ codecs: [] }) }) it('create-transport returns transport params', async () => { - const res = await handleWebSocketMessage(userId, sessionId, 'create-transport', {}) + const res = await handleWebSocketMessage(userId, testState.sessionId, 'create-transport', {}) expect(res?.type).toBe('transport-created') expect(res?.data).toBeDefined() }) it('connect-transport connects with valid params', async () => { - await handleWebSocketMessage(userId, sessionId, 'create-transport', {}) - const res = await handleWebSocketMessage(userId, sessionId, 'connect-transport', { + await handleWebSocketMessage(userId, testState.sessionId, 'create-transport', {}) + const res = await handleWebSocketMessage(userId, testState.sessionId, 'connect-transport', { transportId: 'mock-transport', dtlsParameters: { role: 'client', fingerprints: [] }, }) @@ -76,8 +78,8 @@ describe('webrtcSignaling', () => { id: 'mock-transport', connect: vi.fn().mockRejectedValue(new Error('Connection failed')), }) - await handleWebSocketMessage(userId, sessionId, 'create-transport', {}) - const res = await handleWebSocketMessage(userId, sessionId, 'connect-transport', { + await handleWebSocketMessage(userId, testState.sessionId, 'create-transport', {}) + const res = await handleWebSocketMessage(userId, testState.sessionId, 'connect-transport', { transportId: 'mock-transport', dtlsParameters: { role: 'client', fingerprints: [] }, })