more functional design principles
Some checks failed
ci/woodpecker/pr/pr Pipeline failed

This commit is contained in:
Madison Grubb
2026-02-17 11:17:52 -05:00
parent 1a566e2d80
commit c8d37c98f4
14 changed files with 357 additions and 321 deletions

View File

@@ -9,8 +9,10 @@ import { registerCleanup } from '../utils/shutdown.js'
import { COT_AUTH_TIMEOUT_MS } from '../utils/constants.js' import { COT_AUTH_TIMEOUT_MS } from '../utils/constants.js'
import { acquire } from '../utils/asyncLock.js' import { acquire } from '../utils/asyncLock.js'
let tcpServer = null const serverState = {
let tlsServer = null tcpServer: null,
tlsServer: null,
}
const relaySet = new Set() const relaySet = new Set()
const allSockets = new Set() const allSockets = new Set()
const socketBuffers = new WeakMap() const socketBuffers = new WeakMap()
@@ -44,22 +46,27 @@ function broadcast(senderSocket, rawMessage) {
} }
} }
const createPreview = (payload) => {
try {
const str = payload.toString('utf8')
if (str.startsWith('<')) {
const s = str.length <= 120 ? str : str.slice(0, 120) + '...'
// eslint-disable-next-line no-control-regex -- sanitize control chars for log preview
return s.replace(/[\u0000-\u0008\v\f\u000E-\u001F]/g, '.')
}
return 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex')
}
catch {
return 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex')
}
}
async function processFrame(socket, rawMessage, payload, authenticated) { async function processFrame(socket, rawMessage, payload, authenticated) {
const requireAuth = socket._cotRequireAuth !== false const requireAuth = socket._cotRequireAuth !== false
const debug = socket._cotDebug === true const debug = socket._cotDebug === true
const parsed = parseCotPayload(payload) const parsed = parseCotPayload(payload)
if (debug) { if (debug) {
let preview = payload.length const preview = createPreview(payload)
try {
const str = payload.toString('utf8')
if (str.startsWith('<')) {
const s = str.length <= 120 ? str : str.slice(0, 120) + '...'
// eslint-disable-next-line no-control-regex -- sanitize control chars for log preview
preview = s.replace(/[\u0000-\u0008\v\f\u000E-\u001F]/g, '.')
}
else preview = 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex')
}
catch { preview = 'hex:' + payload.subarray(0, Math.min(40, payload.length)).toString('hex') }
console.log('[cot] payload length:', payload.length, 'parsed:', parsed ? parsed.type : null, 'preview:', preview) console.log('[cot] payload length:', payload.length, 'parsed:', parsed ? parsed.type : null, 'preview:', preview)
} }
if (!parsed) return if (!parsed) return
@@ -108,10 +115,35 @@ async function processFrame(socket, rawMessage, payload, authenticated) {
} }
} }
const parseFrame = (buf) => {
const takResult = parseTakStreamFrame(buf)
if (takResult) return { result: takResult, frameType: 'tak' }
if (buf[0] === 0x3C) {
const xmlResult = parseTraditionalXmlFrame(buf)
if (xmlResult) return { result: xmlResult, frameType: 'traditional' }
}
return { result: null, frameType: null }
}
const processBufferedData = async (socket, buf, authenticated) => {
if (buf.length === 0) return buf
const { result, frameType } = parseFrame(buf)
if (result && socket._cotDebug) {
console.log('[cot] frame parsed as', frameType, 'bytesConsumed=', result.bytesConsumed)
}
if (!result) return buf
const { payload, bytesConsumed } = result
const rawMessage = buf.subarray(0, bytesConsumed)
await processFrame(socket, rawMessage, payload, authenticated)
if (socket.destroyed) return null
const remainingBuf = buf.subarray(bytesConsumed)
socketBuffers.set(socket, remainingBuf)
return processBufferedData(socket, remainingBuf, authenticated)
}
async function onData(socket, data) { async function onData(socket, data) {
let buf = socketBuffers.get(socket) const existingBuf = socketBuffers.get(socket)
if (!buf) buf = Buffer.alloc(0) const buf = Buffer.concat([existingBuf || Buffer.alloc(0), data])
buf = Buffer.concat([buf, data])
socketBuffers.set(socket, buf) socketBuffers.set(socket, buf)
const authenticated = Boolean(socket._cotAuthenticated) const authenticated = Boolean(socket._cotAuthenticated)
@@ -120,22 +152,7 @@ async function onData(socket, data) {
const hex = buf.subarray(0, Math.min(80, buf.length)).toString('hex') const hex = buf.subarray(0, Math.min(80, buf.length)).toString('hex')
console.log('[cot] first chunk len=', buf.length, 'first bytes (hex):', hex, 'starts with 0xBF:', buf[0] === 0xBF, 'starts with <:', buf[0] === 0x3C) console.log('[cot] first chunk len=', buf.length, 'first bytes (hex):', hex, 'starts with 0xBF:', buf[0] === 0xBF, 'starts with <:', buf[0] === 0x3C)
} }
while (buf.length > 0) { await processBufferedData(socket, buf, authenticated)
let result = parseTakStreamFrame(buf)
let frameType = 'tak'
if (!result && buf[0] === 0x3C) {
result = parseTraditionalXmlFrame(buf)
frameType = 'traditional'
}
if (result && socket._cotDebug) console.log('[cot] frame parsed as', frameType, 'bytesConsumed=', result.bytesConsumed)
if (!result) break
const { payload, bytesConsumed } = result
const rawMessage = buf.subarray(0, bytesConsumed)
await processFrame(socket, rawMessage, payload, authenticated)
if (socket.destroyed) return
buf = buf.subarray(bytesConsumed)
socketBuffers.set(socket, buf)
}
} }
function setupSocket(socket, tls = false) { function setupSocket(socket, tls = false) {
@@ -182,16 +199,16 @@ function startCotServers() {
key: readFileSync(keyPath), key: readFileSync(keyPath),
rejectUnauthorized: false, rejectUnauthorized: false,
} }
tlsServer = createTlsServer(tlsOpts, socket => setupSocket(socket, true)) serverState.tlsServer = createTlsServer(tlsOpts, socket => setupSocket(socket, true))
tlsServer.on('error', err => console.error('[cot] TLS server error:', err?.message)) serverState.tlsServer.on('error', err => console.error('[cot] TLS server error:', err?.message))
tlsServer.listen(port, '0.0.0.0', () => { serverState.tlsServer.listen(port, '0.0.0.0', () => {
console.log('[cot] CoT server listening on 0.0.0.0:' + port + ' (TLS) — use this port in ATAK/iTAK and enable SSL') console.log('[cot] CoT server listening on 0.0.0.0:' + port + ' (TLS) — use this port in ATAK/iTAK and enable SSL')
}) })
} }
else { else {
tcpServer = createTcpServer(socket => setupSocket(socket, false)) serverState.tcpServer = createTcpServer(socket => setupSocket(socket, false))
tcpServer.on('error', err => console.error('[cot] TCP server error:', err?.message)) serverState.tcpServer.on('error', err => console.error('[cot] TCP server error:', err?.message))
tcpServer.listen(port, '0.0.0.0', () => { serverState.tcpServer.listen(port, '0.0.0.0', () => {
console.log('[cot] CoT server listening on 0.0.0.0:' + port + ' (plain TCP) — use this port in ATAK/iTAK with SSL disabled') console.log('[cot] CoT server listening on 0.0.0.0:' + port + ' (plain TCP) — use this port in ATAK/iTAK with SSL disabled')
}) })
} }
@@ -209,7 +226,18 @@ export default defineNitroPlugin((nitroApp) => {
// Start immediately so CoT is up before first request in dev; ready may fire late in some setups. // Start immediately so CoT is up before first request in dev; ready may fire late in some setups.
setImmediate(startCotServers) setImmediate(startCotServers)
registerCleanup(async () => { const cleanupServers = () => {
if (serverState.tcpServer) {
serverState.tcpServer.close()
serverState.tcpServer = null
}
if (serverState.tlsServer) {
serverState.tlsServer.close()
serverState.tlsServer = null
}
}
const cleanupSockets = () => {
for (const s of allSockets) { for (const s of allSockets) {
try { try {
s.destroy() s.destroy()
@@ -220,34 +248,15 @@ export default defineNitroPlugin((nitroApp) => {
} }
allSockets.clear() allSockets.clear()
relaySet.clear() relaySet.clear()
if (tcpServer) { }
tcpServer.close()
tcpServer = null registerCleanup(async () => {
} cleanupSockets()
if (tlsServer) { cleanupServers()
tlsServer.close()
tlsServer = null
}
}) })
nitroApp.hooks.hook('close', async () => { nitroApp.hooks.hook('close', async () => {
for (const s of allSockets) { cleanupSockets()
try { cleanupServers()
s.destroy()
}
catch {
/* ignore */
}
}
allSockets.clear()
relaySet.clear()
if (tcpServer) {
tcpServer.close()
tcpServer = null
}
if (tlsServer) {
tlsServer.close()
tlsServer = null
}
}) })
}) })

View File

@@ -5,6 +5,19 @@
const locks = new Map() const locks = new Map()
/**
* Get or create a queue for a lock key.
* @param {string} lockKey - Lock key
* @returns {Promise<any>} Existing or new queue promise
*/
const getOrCreateQueue = (lockKey) => {
const existingQueue = locks.get(lockKey)
if (existingQueue) return existingQueue
const newQueue = Promise.resolve()
locks.set(lockKey, newQueue)
return newQueue
}
/** /**
* Acquire a lock for a key and execute callback. * Acquire a lock for a key and execute callback.
* Only one callback per key executes at a time. * Only one callback per key executes at a time.
@@ -14,12 +27,7 @@ const locks = new Map()
*/ */
export async function acquire(key, callback) { export async function acquire(key, callback) {
const lockKey = String(key) const lockKey = String(key)
let queue = locks.get(lockKey) const queue = getOrCreateQueue(lockKey)
if (!queue) {
queue = Promise.resolve()
locks.set(lockKey, queue)
}
const next = queue.then(() => callback()).finally(() => { const next = queue.then(() => callback()).finally(() => {
if (locks.get(lockKey) === next) { if (locks.get(lockKey) === next) {

View File

@@ -15,21 +15,20 @@ const TRADITIONAL_DELIMITER = Buffer.from('</event>', 'utf8')
/** /**
* @param {Buffer} buf * @param {Buffer} buf
* @param {number} offset * @param {number} offset
* @param {number} value - Accumulated value
* @param {number} shift - Current bit shift
* @param {number} bytesRead - Bytes consumed so far
* @returns {{ value: number, bytesRead: number }} Decoded varint and bytes consumed. * @returns {{ value: number, bytesRead: number }} Decoded varint and bytes consumed.
*/ */
function readVarint(buf, offset) { function readVarint(buf, offset, value = 0, shift = 0, bytesRead = 0) {
let value = 0 if (offset + bytesRead >= buf.length) return { value, bytesRead }
let shift = 0 const b = buf[offset + bytesRead]
let bytesRead = 0 const newValue = value + ((b & 0x7F) << shift)
while (offset + bytesRead < buf.length) { const newBytesRead = bytesRead + 1
const b = buf[offset + bytesRead] if ((b & 0x80) === 0) return { value: newValue, bytesRead: newBytesRead }
bytesRead += 1 const newShift = shift + 7
value += (b & 0x7F) << shift if (newShift > 28) return { value: 0, bytesRead: 0 }
if ((b & 0x80) === 0) return { value, bytesRead } return readVarint(buf, offset, newValue, newShift, newBytesRead)
shift += 7
if (shift > 28) return { value: 0, bytesRead: 0 }
}
return { value, bytesRead }
} }
/** /**
@@ -127,12 +126,14 @@ export function parseCotPayload(payload) {
const uid = String(event['@_uid'] ?? event.uid ?? '') const uid = String(event['@_uid'] ?? event.uid ?? '')
const eventType = String(event['@_type'] ?? event.type ?? '') const eventType = String(event['@_type'] ?? event.type ?? '')
const point = findInObject(parsed, 'point') ?? findInObject(event, 'point') const point = findInObject(parsed, 'point') ?? findInObject(event, 'point')
let lat = Number.NaN const extractCoords = (pt) => {
let lng = Number.NaN if (!pt || typeof pt !== 'object') return { lat: Number.NaN, lng: Number.NaN }
if (point && typeof point === 'object') { return {
lat = Number(point['@_lat'] ?? point.lat) lat: Number(pt['@_lat'] ?? pt.lat),
lng = Number(point['@_lon'] ?? point.lon ?? point['@_lng'] ?? point.lng) lng: Number(pt['@_lon'] ?? pt.lon ?? pt['@_lng'] ?? pt.lng),
}
} }
const { lat, lng } = extractCoords(point)
if (!Number.isFinite(lat) || !Number.isFinite(lng)) return null if (!Number.isFinite(lat) || !Number.isFinite(lng)) return null
const detail = findInObject(parsed, 'detail') const detail = findInObject(parsed, 'detail')

View File

@@ -5,11 +5,13 @@
import { SHUTDOWN_TIMEOUT_MS } from './constants.js' import { SHUTDOWN_TIMEOUT_MS } from './constants.js'
const cleanupFunctions = [] const cleanupFunctions = []
let isShuttingDown = false const shutdownState = {
isShuttingDown: false,
}
export function clearCleanup() { export function clearCleanup() {
cleanupFunctions.length = 0 cleanupFunctions.length = 0
isShuttingDown = false shutdownState.isShuttingDown = false
} }
export function registerCleanup(fn) { export function registerCleanup(fn) {
@@ -17,17 +19,25 @@ export function registerCleanup(fn) {
cleanupFunctions.push(fn) cleanupFunctions.push(fn)
} }
async function executeCleanup() { const executeCleanupFunction = async (fn, index) => {
if (isShuttingDown) return try {
isShuttingDown = true await fn()
for (let i = cleanupFunctions.length - 1; i >= 0; i--) {
try {
await cleanupFunctions[i]()
}
catch (error) {
console.error(`[shutdown] Cleanup function ${i} failed:`, error?.message || String(error))
}
} }
catch (error) {
console.error(`[shutdown] Cleanup function ${index} failed:`, error?.message || String(error))
}
}
const executeCleanupReverse = async (functions, index = functions.length - 1) => {
if (index < 0) return
await executeCleanupFunction(functions[index], index)
return executeCleanupReverse(functions, index - 1)
}
async function executeCleanup() {
if (shutdownState.isShuttingDown) return
shutdownState.isShuttingDown = true
await executeCleanupReverse(cleanupFunctions)
} }
export async function graceful(error) { export async function graceful(error) {

View File

@@ -1,3 +1,18 @@
/**
* Encode a number as varint bytes (little-endian, continuation bit).
* @param {number} value - Value to encode
* @param {number[]} bytes - Accumulated bytes (default empty)
* @returns {number[]} Varint bytes
*/
const encodeVarint = (value, bytes = []) => {
const byte = value & 0x7F
const remaining = value >>> 7
if (remaining === 0) {
return [...bytes, byte]
}
return encodeVarint(remaining, [...bytes, byte | 0x80])
}
/** /**
* Build a TAK Protocol stream frame: 0xBF, varint payload length, payload. * Build a TAK Protocol stream frame: 0xBF, varint payload length, payload.
* @param {string|Buffer} payload - UTF-8 payload (e.g. CoT XML) * @param {string|Buffer} payload - UTF-8 payload (e.g. CoT XML)
@@ -5,17 +20,7 @@
*/ */
export function buildTakFrame(payload) { export function buildTakFrame(payload) {
const buf = Buffer.isBuffer(payload) ? payload : Buffer.from(payload, 'utf8') const buf = Buffer.isBuffer(payload) ? payload : Buffer.from(payload, 'utf8')
let n = buf.length const varint = encodeVarint(buf.length)
const varint = []
while (true) {
const byte = n & 0x7F
n >>>= 7
if (n === 0) {
varint.push(byte)
break
}
varint.push(byte | 0x80)
}
return Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint), buf]) return Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint), buf])
} }

View File

@@ -7,12 +7,12 @@ describe('asyncLock', () => {
}) })
it('executes callback immediately when no lock exists', async () => { it('executes callback immediately when no lock exists', async () => {
let executed = false const executed = { value: false }
await acquire('test', async () => { await acquire('test', async () => {
executed = true executed.value = true
return 42 return 42
}) })
expect(executed).toBe(true) expect(executed.value).toBe(true)
}) })
it('returns callback result', async () => { it('returns callback result', async () => {
@@ -24,43 +24,35 @@ describe('asyncLock', () => {
it('serializes concurrent operations on same key', async () => { it('serializes concurrent operations on same key', async () => {
const results = [] const results = []
const promises = [] const promises = Array.from({ length: 5 }, (_, i) =>
acquire('same-key', async () => {
for (let i = 0; i < 5; i++) { results.push(`start-${i}`)
promises.push( await new Promise(resolve => setTimeout(resolve, 10))
acquire('same-key', async () => { results.push(`end-${i}`)
results.push(`start-${i}`) return i
await new Promise(resolve => setTimeout(resolve, 10)) }),
results.push(`end-${i}`) )
return i
}),
)
}
await Promise.all(promises) await Promise.all(promises)
// Operations should be serialized: start-end pairs should not interleave // Operations should be serialized: start-end pairs should not interleave
expect(results.length).toBe(10) expect(results.length).toBe(10)
for (let i = 0; i < 5; i++) { Array.from({ length: 5 }, (_, i) => {
expect(results[i * 2]).toBe(`start-${i}`) expect(results[i * 2]).toBe(`start-${i}`)
expect(results[i * 2 + 1]).toBe(`end-${i}`) expect(results[i * 2 + 1]).toBe(`end-${i}`)
} })
}) })
it('allows parallel operations on different keys', async () => { it('allows parallel operations on different keys', async () => {
const results = [] const results = []
const promises = [] const promises = Array.from({ length: 5 }, (_, i) =>
acquire(`key-${i}`, async () => {
for (let i = 0; i < 5; i++) { results.push(`start-${i}`)
promises.push( await new Promise(resolve => setTimeout(resolve, 10))
acquire(`key-${i}`, async () => { results.push(`end-${i}`)
results.push(`start-${i}`) return i
await new Promise(resolve => setTimeout(resolve, 10)) }),
results.push(`end-${i}`) )
return i
}),
)
}
await Promise.all(promises) await Promise.all(promises)
@@ -74,10 +66,10 @@ describe('asyncLock', () => {
}) })
it('handles errors and releases lock', async () => { it('handles errors and releases lock', async () => {
let callCount = 0 const callCount = { value: 0 }
try { try {
await acquire('error-key', async () => { await acquire('error-key', async () => {
callCount++ callCount.value++
throw new Error('Test error') throw new Error('Test error')
}) })
} }
@@ -87,27 +79,22 @@ describe('asyncLock', () => {
// Lock should be released, next operation should execute // Lock should be released, next operation should execute
await acquire('error-key', async () => { await acquire('error-key', async () => {
callCount++ callCount.value++
return 'success' return 'success'
}) })
expect(callCount).toBe(2) expect(callCount.value).toBe(2)
}) })
it('maintains lock ordering', async () => { it('maintains lock ordering', async () => {
const order = [] const order = []
const promises = [] const promises = Array.from({ length: 3 }, (_, i) =>
acquire('ordered', async () => {
for (let i = 0; i < 3; i++) { order.push(`before-${i}`)
const idx = i await new Promise(resolve => setTimeout(resolve, 5))
promises.push( order.push(`after-${i}`)
acquire('ordered', async () => { }),
order.push(`before-${idx}`) )
await new Promise(resolve => setTimeout(resolve, 5))
order.push(`after-${idx}`)
}),
)
}
await Promise.all(promises) await Promise.all(promises)

View File

@@ -1,19 +1,18 @@
import { describe, it, expect } from 'vitest' import { describe, it, expect } from 'vitest'
import { parseTakStreamFrame, parseTraditionalXmlFrame, parseCotPayload } from '../../../server/utils/cotParser.js' import { parseTakStreamFrame, parseTraditionalXmlFrame, parseCotPayload } from '../../../server/utils/cotParser.js'
const encodeVarint = (value, bytes = []) => {
const byte = value & 0x7F
const remaining = value >>> 7
if (remaining === 0) {
return [...bytes, byte]
}
return encodeVarint(remaining, [...bytes, byte | 0x80])
}
function buildTakFrame(payload) { function buildTakFrame(payload) {
const buf = Buffer.from(payload, 'utf8') const buf = Buffer.from(payload, 'utf8')
let n = buf.length const varint = encodeVarint(buf.length)
const varint = []
while (true) {
const byte = n & 0x7F
n >>>= 7
if (n === 0) {
varint.push(byte)
break
}
varint.push(byte | 0x80)
}
return Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint), buf]) return Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint), buf])
} }
@@ -42,14 +41,7 @@ describe('cotParser', () => {
it('returns null for payload length exceeding max', () => { it('returns null for payload length exceeding max', () => {
const hugeLen = 64 * 1024 + 1 const hugeLen = 64 * 1024 + 1
const varint = [] const varint = encodeVarint(hugeLen)
let n = hugeLen
while (true) {
varint.push(n & 0x7F)
n >>>= 7
if (n === 0) break
varint[varint.length - 1] |= 0x80
}
const buf = Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint)]) const buf = Buffer.concat([Buffer.from([0xBF]), Buffer.from(varint)])
expect(parseTakStreamFrame(buf)).toBeNull() expect(parseTakStreamFrame(buf)).toBeNull()
}) })

View File

@@ -13,23 +13,25 @@ import {
import { withTemporaryEnv } from '../helpers/env.js' import { withTemporaryEnv } from '../helpers/env.js'
describe('cotSsl', () => { describe('cotSsl', () => {
let testCertDir const testPaths = {
let testCertPath testCertDir: null,
let testKeyPath testCertPath: null,
testKeyPath: null,
}
beforeEach(() => { beforeEach(() => {
testCertDir = join(tmpdir(), `kestrelos-test-${Date.now()}`) testPaths.testCertDir = join(tmpdir(), `kestrelos-test-${Date.now()}`)
mkdirSync(testCertDir, { recursive: true }) mkdirSync(testPaths.testCertDir, { recursive: true })
testCertPath = join(testCertDir, 'cert.pem') testPaths.testCertPath = join(testPaths.testCertDir, 'cert.pem')
testKeyPath = join(testCertDir, 'key.pem') testPaths.testKeyPath = join(testPaths.testCertDir, 'key.pem')
writeFileSync(testCertPath, '-----BEGIN CERTIFICATE-----\nTEST\n-----END CERTIFICATE-----\n') writeFileSync(testPaths.testCertPath, '-----BEGIN CERTIFICATE-----\nTEST\n-----END CERTIFICATE-----\n')
writeFileSync(testKeyPath, '-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n') writeFileSync(testPaths.testKeyPath, '-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n')
}) })
afterEach(() => { afterEach(() => {
try { try {
if (existsSync(testCertPath)) unlinkSync(testCertPath) if (existsSync(testPaths.testCertPath)) unlinkSync(testPaths.testCertPath)
if (existsSync(testKeyPath)) unlinkSync(testKeyPath) if (existsSync(testPaths.testKeyPath)) unlinkSync(testPaths.testKeyPath)
} }
catch { catch {
// Ignore cleanup errors // Ignore cleanup errors
@@ -78,22 +80,22 @@ describe('cotSsl', () => {
}) })
it('returns paths from COT_SSL_CERT and COT_SSL_KEY env vars', () => { it('returns paths from COT_SSL_CERT and COT_SSL_KEY env vars', () => {
withTemporaryEnv({ COT_SSL_CERT: testCertPath, COT_SSL_KEY: testKeyPath }, () => { withTemporaryEnv({ COT_SSL_CERT: testPaths.testCertPath, COT_SSL_KEY: testPaths.testKeyPath }, () => {
expect(getCotSslPaths()).toEqual({ certPath: testCertPath, keyPath: testKeyPath }) expect(getCotSslPaths()).toEqual({ certPath: testPaths.testCertPath, keyPath: testPaths.testKeyPath })
}) })
}) })
it('returns paths from config parameter when env vars not set', () => { it('returns paths from config parameter when env vars not set', () => {
withTemporaryEnv({ COT_SSL_CERT: undefined, COT_SSL_KEY: undefined }, () => { withTemporaryEnv({ COT_SSL_CERT: undefined, COT_SSL_KEY: undefined }, () => {
const config = { cotSslCert: testCertPath, cotSslKey: testKeyPath } const config = { cotSslCert: testPaths.testCertPath, cotSslKey: testPaths.testKeyPath }
expect(getCotSslPaths(config)).toEqual({ certPath: testCertPath, keyPath: testKeyPath }) expect(getCotSslPaths(config)).toEqual({ certPath: testPaths.testCertPath, keyPath: testPaths.testKeyPath })
}) })
}) })
it('prefers env vars over config parameter', () => { it('prefers env vars over config parameter', () => {
withTemporaryEnv({ COT_SSL_CERT: testCertPath, COT_SSL_KEY: testKeyPath }, () => { withTemporaryEnv({ COT_SSL_CERT: testPaths.testCertPath, COT_SSL_KEY: testPaths.testKeyPath }, () => {
const config = { cotSslCert: '/other/cert.pem', cotSslKey: '/other/key.pem' } const config = { cotSslCert: '/other/cert.pem', cotSslKey: '/other/key.pem' }
expect(getCotSslPaths(config)).toEqual({ certPath: testCertPath, keyPath: testKeyPath }) expect(getCotSslPaths(config)).toEqual({ certPath: testPaths.testCertPath, keyPath: testPaths.testKeyPath })
}) })
}) })
@@ -113,7 +115,7 @@ describe('cotSsl', () => {
}) })
it('throws error when openssl command fails', () => { it('throws error when openssl command fails', () => {
const invalidCertPath = join(testCertDir, 'invalid.pem') const invalidCertPath = join(testPaths.testCertDir, 'invalid.pem')
writeFileSync(invalidCertPath, 'invalid cert content') writeFileSync(invalidCertPath, 'invalid cert content')
expect(() => { expect(() => {
buildP12FromCertPath(invalidCertPath, 'password') buildP12FromCertPath(invalidCertPath, 'password')
@@ -121,7 +123,7 @@ describe('cotSsl', () => {
}) })
it('cleans up temp file on error', () => { it('cleans up temp file on error', () => {
const invalidCertPath = join(testCertDir, 'invalid.pem') const invalidCertPath = join(testPaths.testCertDir, 'invalid.pem')
writeFileSync(invalidCertPath, 'invalid cert content') writeFileSync(invalidCertPath, 'invalid cert content')
try { try {
buildP12FromCertPath(invalidCertPath, 'password') buildP12FromCertPath(invalidCertPath, 'password')

View File

@@ -17,18 +17,20 @@ vi.mock('../../../server/utils/mediasoup.js', () => ({
})) }))
describe('liveSessions', () => { describe('liveSessions', () => {
let sessionId const testState = {
sessionId: null,
}
beforeEach(async () => { beforeEach(async () => {
clearSessions() clearSessions()
const session = await createSession('test-user', 'Test Session') const session = await createSession('test-user', 'Test Session')
sessionId = session.id testState.sessionId = session.id
}) })
it('creates a session with WebRTC fields', () => { it('creates a session with WebRTC fields', () => {
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
expect(session).toBeDefined() expect(session).toBeDefined()
expect(session.id).toBe(sessionId) expect(session.id).toBe(testState.sessionId)
expect(session.userId).toBe('test-user') expect(session.userId).toBe('test-user')
expect(session.label).toBe('Test Session') expect(session.label).toBe('Test Session')
expect(session.routerId).toBeNull() expect(session.routerId).toBeNull()
@@ -37,45 +39,45 @@ describe('liveSessions', () => {
}) })
it('updates location', async () => { it('updates location', async () => {
await updateLiveSession(sessionId, { lat: 37.7, lng: -122.4 }) await updateLiveSession(testState.sessionId, { lat: 37.7, lng: -122.4 })
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
expect(session.lat).toBe(37.7) expect(session.lat).toBe(37.7)
expect(session.lng).toBe(-122.4) expect(session.lng).toBe(-122.4)
}) })
it('updates WebRTC fields', async () => { it('updates WebRTC fields', async () => {
await updateLiveSession(sessionId, { routerId: 'router-1', producerId: 'producer-1', transportId: 'transport-1' }) await updateLiveSession(testState.sessionId, { routerId: 'router-1', producerId: 'producer-1', transportId: 'transport-1' })
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
expect(session.routerId).toBe('router-1') expect(session.routerId).toBe('router-1')
expect(session.producerId).toBe('producer-1') expect(session.producerId).toBe('producer-1')
expect(session.transportId).toBe('transport-1') expect(session.transportId).toBe('transport-1')
}) })
it('returns hasStream instead of hasSnapshot', async () => { it('returns hasStream instead of hasSnapshot', async () => {
await updateLiveSession(sessionId, { producerId: 'producer-1' }) await updateLiveSession(testState.sessionId, { producerId: 'producer-1' })
const active = await getActiveSessions() const active = await getActiveSessions()
const session = active.find(s => s.id === sessionId) const session = active.find(s => s.id === testState.sessionId)
expect(session).toBeDefined() expect(session).toBeDefined()
expect(session.hasStream).toBe(true) expect(session.hasStream).toBe(true)
}) })
it('returns hasStream false when no producer', async () => { it('returns hasStream false when no producer', async () => {
const active = await getActiveSessions() const active = await getActiveSessions()
const session = active.find(s => s.id === sessionId) const session = active.find(s => s.id === testState.sessionId)
expect(session).toBeDefined() expect(session).toBeDefined()
expect(session.hasStream).toBe(false) expect(session.hasStream).toBe(false)
}) })
it('deletes a session', async () => { it('deletes a session', async () => {
await deleteLiveSession(sessionId) await deleteLiveSession(testState.sessionId)
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
expect(session).toBeUndefined() expect(session).toBeUndefined()
}) })
it('getActiveSessionByUserId returns session for same user when active', async () => { it('getActiveSessionByUserId returns session for same user when active', async () => {
const found = await getActiveSessionByUserId('test-user') const found = await getActiveSessionByUserId('test-user')
expect(found).toBeDefined() expect(found).toBeDefined()
expect(found.id).toBe(sessionId) expect(found.id).toBe(testState.sessionId)
}) })
it('getActiveSessionByUserId returns undefined for unknown user', async () => { it('getActiveSessionByUserId returns undefined for unknown user', async () => {
@@ -84,18 +86,18 @@ describe('liveSessions', () => {
}) })
it('getActiveSessionByUserId returns undefined for expired session', async () => { it('getActiveSessionByUserId returns undefined for expired session', async () => {
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
session.updatedAt = Date.now() - 120_000 session.updatedAt = Date.now() - 120_000
const found = await getActiveSessionByUserId('test-user') const found = await getActiveSessionByUserId('test-user')
expect(found).toBeUndefined() expect(found).toBeUndefined()
}) })
it('getActiveSessions removes expired sessions', async () => { it('getActiveSessions removes expired sessions', async () => {
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
session.updatedAt = Date.now() - 120_000 session.updatedAt = Date.now() - 120_000
const active = await getActiveSessions() const active = await getActiveSessions()
expect(active.find(s => s.id === sessionId)).toBeUndefined() expect(active.find(s => s.id === testState.sessionId)).toBeUndefined()
expect(getLiveSession(sessionId)).toBeUndefined() expect(getLiveSession(testState.sessionId)).toBeUndefined()
}) })
it('getActiveSessions runs cleanup for expired session with producer and transport', async () => { it('getActiveSessions runs cleanup for expired session with producer and transport', async () => {
@@ -105,19 +107,19 @@ describe('liveSessions', () => {
getProducer.mockReturnValue(mockProducer) getProducer.mockReturnValue(mockProducer)
getTransport.mockReturnValue(mockTransport) getTransport.mockReturnValue(mockTransport)
closeRouter.mockResolvedValue(undefined) closeRouter.mockResolvedValue(undefined)
await updateLiveSession(sessionId, { producerId: 'p1', transportId: 't1', routerId: 'r1' }) await updateLiveSession(testState.sessionId, { producerId: 'p1', transportId: 't1', routerId: 'r1' })
const session = getLiveSession(sessionId) const session = getLiveSession(testState.sessionId)
session.updatedAt = Date.now() - 120_000 session.updatedAt = Date.now() - 120_000
const active = await getActiveSessions() const active = await getActiveSessions()
expect(active.find(s => s.id === sessionId)).toBeUndefined() expect(active.find(s => s.id === testState.sessionId)).toBeUndefined()
expect(mockProducer.close).toHaveBeenCalled() expect(mockProducer.close).toHaveBeenCalled()
expect(mockTransport.close).toHaveBeenCalled() expect(mockTransport.close).toHaveBeenCalled()
expect(closeRouter).toHaveBeenCalledWith(sessionId) expect(closeRouter).toHaveBeenCalledWith(testState.sessionId)
}) })
it('getOrCreateSession returns existing active session', async () => { it('getOrCreateSession returns existing active session', async () => {
const session = await getOrCreateSession('test-user', 'New Label') const session = await getOrCreateSession('test-user', 'New Label')
expect(session.id).toBe(sessionId) expect(session.id).toBe(testState.sessionId)
expect(session.userId).toBe('test-user') expect(session.userId).toBe('test-user')
}) })
@@ -128,10 +130,9 @@ describe('liveSessions', () => {
}) })
it('getOrCreateSession handles concurrent calls atomically', async () => { it('getOrCreateSession handles concurrent calls atomically', async () => {
const promises = [] const promises = Array.from({ length: 5 }, () =>
for (let i = 0; i < 5; i++) { getOrCreateSession('concurrent-user', 'Concurrent'),
promises.push(getOrCreateSession('concurrent-user', 'Concurrent')) )
}
const sessions = await Promise.all(promises) const sessions = await Promise.all(promises)
const uniqueIds = new Set(sessions.map(s => s.id)) const uniqueIds = new Set(sessions.map(s => s.id))
expect(uniqueIds.size).toBe(1) expect(uniqueIds.size).toBe(1)

View File

@@ -2,41 +2,43 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'
import { info, error, warn, debug, setContext, clearContext, runWithContext } from '../../server/utils/logger.js' import { info, error, warn, debug, setContext, clearContext, runWithContext } from '../../server/utils/logger.js'
describe('logger', () => { describe('logger', () => {
let originalLog const testState = {
let originalError originalLog: null,
let originalWarn originalError: null,
let originalDebug originalWarn: null,
let logCalls originalDebug: null,
let errorCalls logCalls: [],
let warnCalls errorCalls: [],
let debugCalls warnCalls: [],
debugCalls: [],
}
beforeEach(() => { beforeEach(() => {
logCalls = [] testState.logCalls = []
errorCalls = [] testState.errorCalls = []
warnCalls = [] testState.warnCalls = []
debugCalls = [] testState.debugCalls = []
originalLog = console.log testState.originalLog = console.log
originalError = console.error testState.originalError = console.error
originalWarn = console.warn testState.originalWarn = console.warn
originalDebug = console.debug testState.originalDebug = console.debug
console.log = vi.fn((...args) => logCalls.push(args)) console.log = vi.fn((...args) => testState.logCalls.push(args))
console.error = vi.fn((...args) => errorCalls.push(args)) console.error = vi.fn((...args) => testState.errorCalls.push(args))
console.warn = vi.fn((...args) => warnCalls.push(args)) console.warn = vi.fn((...args) => testState.warnCalls.push(args))
console.debug = vi.fn((...args) => debugCalls.push(args)) console.debug = vi.fn((...args) => testState.debugCalls.push(args))
}) })
afterEach(() => { afterEach(() => {
console.log = originalLog console.log = testState.originalLog
console.error = originalError console.error = testState.originalError
console.warn = originalWarn console.warn = testState.originalWarn
console.debug = originalDebug console.debug = testState.originalDebug
}) })
it('logs info message', () => { it('logs info message', () => {
info('Test message') info('Test message')
expect(logCalls.length).toBe(1) expect(testState.logCalls.length).toBe(1)
const logMsg = logCalls[0][0] const logMsg = testState.logCalls[0][0]
expect(logMsg).toContain('[INFO]') expect(logMsg).toContain('[INFO]')
expect(logMsg).toContain('Test message') expect(logMsg).toContain('Test message')
}) })
@@ -44,7 +46,7 @@ describe('logger', () => {
it('includes request context when set', async () => { it('includes request context when set', async () => {
await runWithContext('req-123', 'user-456', async () => { await runWithContext('req-123', 'user-456', async () => {
info('Test message') info('Test message')
const logMsg = logCalls[0][0] const logMsg = testState.logCalls[0][0]
expect(logMsg).toContain('req-123') expect(logMsg).toContain('req-123')
expect(logMsg).toContain('user-456') expect(logMsg).toContain('user-456')
}) })
@@ -52,7 +54,7 @@ describe('logger', () => {
it('includes additional context', () => { it('includes additional context', () => {
info('Test message', { key: 'value', count: 42 }) info('Test message', { key: 'value', count: 42 })
const logMsg = logCalls[0][0] const logMsg = testState.logCalls[0][0]
expect(logMsg).toContain('key') expect(logMsg).toContain('key')
expect(logMsg).toContain('value') expect(logMsg).toContain('value')
expect(logMsg).toContain('42') expect(logMsg).toContain('42')
@@ -61,8 +63,8 @@ describe('logger', () => {
it('logs error with stack trace', () => { it('logs error with stack trace', () => {
const err = new Error('Test error') const err = new Error('Test error')
error('Failed', { error: err }) error('Failed', { error: err })
expect(errorCalls.length).toBe(1) expect(testState.errorCalls.length).toBe(1)
const errorMsg = errorCalls[0][0] const errorMsg = testState.errorCalls[0][0]
expect(errorMsg).toContain('[ERROR]') expect(errorMsg).toContain('[ERROR]')
expect(errorMsg).toContain('Failed') expect(errorMsg).toContain('Failed')
expect(errorMsg).toContain('stack') expect(errorMsg).toContain('stack')
@@ -70,8 +72,8 @@ describe('logger', () => {
it('logs warning', () => { it('logs warning', () => {
warn('Warning message') warn('Warning message')
expect(warnCalls.length).toBe(1) expect(testState.warnCalls.length).toBe(1)
const warnMsg = warnCalls[0][0] const warnMsg = testState.warnCalls[0][0]
expect(warnMsg).toContain('[WARN]') expect(warnMsg).toContain('[WARN]')
}) })
@@ -79,7 +81,7 @@ describe('logger', () => {
const originalEnv = process.env.NODE_ENV const originalEnv = process.env.NODE_ENV
process.env.NODE_ENV = 'development' process.env.NODE_ENV = 'development'
debug('Debug message') debug('Debug message')
expect(debugCalls.length).toBe(1) expect(testState.debugCalls.length).toBe(1)
process.env.NODE_ENV = originalEnv process.env.NODE_ENV = originalEnv
}) })
@@ -87,19 +89,19 @@ describe('logger', () => {
const originalEnv = process.env.NODE_ENV const originalEnv = process.env.NODE_ENV
process.env.NODE_ENV = 'production' process.env.NODE_ENV = 'production'
debug('Debug message') debug('Debug message')
expect(debugCalls.length).toBe(0) expect(testState.debugCalls.length).toBe(0)
process.env.NODE_ENV = originalEnv process.env.NODE_ENV = originalEnv
}) })
it('clears context', async () => { it('clears context', async () => {
await runWithContext('req-123', 'user-456', async () => { await runWithContext('req-123', 'user-456', async () => {
info('Test with context') info('Test with context')
const logMsg = logCalls[0][0] const logMsg = testState.logCalls[0][0]
expect(logMsg).toContain('req-123') expect(logMsg).toContain('req-123')
}) })
// Context should be cleared after runWithContext completes // Context should be cleared after runWithContext completes
info('Test without context') info('Test without context')
const logMsg = logCalls[logCalls.length - 1][0] const logMsg = testState.logCalls[testState.logCalls.length - 1][0]
expect(logMsg).not.toContain('req-123') expect(logMsg).not.toContain('req-123')
}) })
@@ -107,12 +109,12 @@ describe('logger', () => {
await runWithContext(null, null, async () => { await runWithContext(null, null, async () => {
setContext('req-123', 'user-456') setContext('req-123', 'user-456')
info('Test message') info('Test message')
const logMsg = logCalls[0][0] const logMsg = testState.logCalls[0][0]
expect(logMsg).toContain('req-123') expect(logMsg).toContain('req-123')
expect(logMsg).toContain('user-456') expect(logMsg).toContain('user-456')
clearContext() clearContext()
info('Test after clear') info('Test after clear')
const logMsg2 = logCalls[1][0] const logMsg2 = testState.logCalls[1][0]
expect(logMsg2).not.toContain('req-123') expect(logMsg2).not.toContain('req-123')
}) })
}) })

View File

@@ -3,28 +3,30 @@ import { createSession, deleteLiveSession } from '../../../server/utils/liveSess
import { getRouter, createTransport, closeRouter, getTransport, createProducer, getProducer, createConsumer } from '../../../server/utils/mediasoup.js' import { getRouter, createTransport, closeRouter, getTransport, createProducer, getProducer, createConsumer } from '../../../server/utils/mediasoup.js'
describe('Mediasoup', () => { describe('Mediasoup', () => {
let sessionId const testState = {
sessionId: null,
}
beforeEach(() => { beforeEach(() => {
sessionId = createSession('test-user', 'Test Session').id testState.sessionId = createSession('test-user', 'Test Session').id
}) })
afterEach(async () => { afterEach(async () => {
if (sessionId) { if (testState.sessionId) {
await closeRouter(sessionId) await closeRouter(testState.sessionId)
deleteLiveSession(sessionId) deleteLiveSession(testState.sessionId)
} }
}) })
it('should create a router for a session', async () => { it('should create a router for a session', async () => {
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
expect(router).toBeDefined() expect(router).toBeDefined()
expect(router.id).toBeDefined() expect(router.id).toBeDefined()
expect(router.rtpCapabilities).toBeDefined() expect(router.rtpCapabilities).toBeDefined()
}) })
it('should create a transport', async () => { it('should create a transport', async () => {
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport, params } = await createTransport(router) const { transport, params } = await createTransport(router)
expect(transport).toBeDefined() expect(transport).toBeDefined()
expect(params.id).toBe(transport.id) expect(params.id).toBe(transport.id)
@@ -34,7 +36,7 @@ describe('Mediasoup', () => {
}) })
it('should create a transport with requestHost IPv4 and return valid params', async () => { it('should create a transport with requestHost IPv4 and return valid params', async () => {
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport, params } = await createTransport(router, '192.168.2.100') const { transport, params } = await createTransport(router, '192.168.2.100')
expect(transport).toBeDefined() expect(transport).toBeDefined()
expect(params.id).toBe(transport.id) expect(params.id).toBe(transport.id)
@@ -45,13 +47,13 @@ describe('Mediasoup', () => {
}) })
it('should reuse router for same session', async () => { it('should reuse router for same session', async () => {
const router1 = await getRouter(sessionId) const router1 = await getRouter(testState.sessionId)
const router2 = await getRouter(sessionId) const router2 = await getRouter(testState.sessionId)
expect(router1.id).toBe(router2.id) expect(router1.id).toBe(router2.id)
}) })
it('should get transport by ID', async () => { it('should get transport by ID', async () => {
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport } = await createTransport(router, true) const { transport } = await createTransport(router, true)
const retrieved = getTransport(transport.id) const retrieved = getTransport(transport.id)
expect(retrieved).toBe(transport) expect(retrieved).toBe(transport)
@@ -59,7 +61,7 @@ describe('Mediasoup', () => {
it.skip('should create a producer with mock track', async () => { it.skip('should create a producer with mock track', async () => {
// Mediasoup produce() requires a real MediaStreamTrack (native addon); plain mocks fail with "invalid kind" // Mediasoup produce() requires a real MediaStreamTrack (native addon); plain mocks fail with "invalid kind"
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport } = await createTransport(router, true) const { transport } = await createTransport(router, true)
const mockTrack = { const mockTrack = {
id: 'mock-track-id', id: 'mock-track-id',
@@ -77,24 +79,25 @@ describe('Mediasoup', () => {
it.skip('should cleanup producer on close', async () => { it.skip('should cleanup producer on close', async () => {
// Depends on createProducer which requires real MediaStreamTrack in Node // Depends on createProducer which requires real MediaStreamTrack in Node
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport } = await createTransport(router, true) const { transport } = await createTransport(router, true)
const mockTrack = { id: 'mock-track-id', kind: 'video', enabled: true, readyState: 'live' } const mockTrack = { id: 'mock-track-id', kind: 'video', enabled: true, readyState: 'live' }
const producer = await createProducer(transport, mockTrack) const producer = await createProducer(transport, mockTrack)
const producerId = producer.id const producerId = producer.id
expect(getProducer(producerId)).toBe(producer) expect(getProducer(producerId)).toBe(producer)
producer.close() producer.close()
let attempts = 0 const waitForCleanup = async (maxAttempts = 50) => {
while (getProducer(producerId) && attempts < 50) { if (maxAttempts <= 0 || !getProducer(producerId)) return
await new Promise(resolve => setTimeout(resolve, 10)) await new Promise(resolve => setTimeout(resolve, 10))
attempts++ return waitForCleanup(maxAttempts - 1)
} }
await waitForCleanup()
expect(getProducer(producerId) || producer.closed).toBeTruthy() expect(getProducer(producerId) || producer.closed).toBeTruthy()
}) })
it.skip('should create a consumer', async () => { it.skip('should create a consumer', async () => {
// Depends on createProducer which requires real MediaStreamTrack in Node // Depends on createProducer which requires real MediaStreamTrack in Node
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport } = await createTransport(router, true) const { transport } = await createTransport(router, true)
const mockTrack = { id: 'mock-track-id', kind: 'video', enabled: true, readyState: 'live' } const mockTrack = { id: 'mock-track-id', kind: 'video', enabled: true, readyState: 'live' }
const producer = await createProducer(transport, mockTrack) const producer = await createProducer(transport, mockTrack)
@@ -110,7 +113,7 @@ describe('Mediasoup', () => {
}) })
it('should cleanup transport on close', async () => { it('should cleanup transport on close', async () => {
const router = await getRouter(sessionId) const router = await getRouter(testState.sessionId)
const { transport } = await createTransport(router, true) const { transport } = await createTransport(router, true)
const transportId = transport.id const transportId = transport.id
expect(getTransport(transportId)).toBe(transport) expect(getTransport(transportId)).toBe(transport)
@@ -118,19 +121,20 @@ describe('Mediasoup', () => {
transport.close() transport.close()
// Wait for async cleanup (mediasoup fires 'close' event asynchronously) // Wait for async cleanup (mediasoup fires 'close' event asynchronously)
// Use a promise that resolves when transport is removed or timeout // Use a promise that resolves when transport is removed or timeout
let attempts = 0 const waitForCleanup = async (maxAttempts = 50) => {
while (getTransport(transportId) && attempts < 50) { if (maxAttempts <= 0 || !getTransport(transportId)) return
await new Promise(resolve => setTimeout(resolve, 10)) await new Promise(resolve => setTimeout(resolve, 10))
attempts++ return waitForCleanup(maxAttempts - 1)
} }
await waitForCleanup()
// Transport should be removed from Map (or at least closed) // Transport should be removed from Map (or at least closed)
expect(getTransport(transportId) || transport.closed).toBeTruthy() expect(getTransport(transportId) || transport.closed).toBeTruthy()
}) })
it('should cleanup router on closeRouter', async () => { it('should cleanup router on closeRouter', async () => {
await getRouter(sessionId) await getRouter(testState.sessionId)
await closeRouter(sessionId) await closeRouter(testState.sessionId)
const routerAfter = await getRouter(sessionId) const routerAfter = await getRouter(testState.sessionId)
// New router should have different ID (or same if cached, but old one should be closed) // New router should have different ID (or same if cached, but old one should be closed)
// This test verifies closeRouter doesn't throw // This test verifies closeRouter doesn't throw
expect(routerAfter).toBeDefined() expect(routerAfter).toBeDefined()

View File

@@ -28,12 +28,23 @@ function getRelativeImports(content) {
const paths = [] const paths = []
const fromRegex = /from\s+['"]([^'"]+)['"]/g const fromRegex = /from\s+['"]([^'"]+)['"]/g
const requireRegex = /require\s*\(\s*['"]([^'"]+)['"]\s*\)/g const requireRegex = /require\s*\(\s*['"]([^'"]+)['"]\s*\)/g
for (const re of [fromRegex, requireRegex]) { const extractMatches = (regex, text) => {
let m const matches = []
while ((m = re.exec(content)) !== null) { const execRegex = (r) => {
const p = m[1] const match = r.exec(text)
if (p.startsWith('.')) paths.push(p) if (match) {
matches.push(match[1])
return execRegex(r)
}
return matches
} }
return execRegex(regex)
}
for (const re of [fromRegex, requireRegex]) {
const matches = extractMatches(re, content)
matches.forEach((p) => {
if (p.startsWith('.')) paths.push(p)
})
} }
return paths return paths
} }

View File

@@ -2,20 +2,22 @@ import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'
import { registerCleanup, graceful, clearCleanup, initShutdownHandlers } from '../../server/utils/shutdown.js' import { registerCleanup, graceful, clearCleanup, initShutdownHandlers } from '../../server/utils/shutdown.js'
describe('shutdown', () => { describe('shutdown', () => {
let originalExit const testState = {
let exitCalls originalExit: null,
exitCalls: [],
}
beforeEach(() => { beforeEach(() => {
clearCleanup() clearCleanup()
exitCalls = [] testState.exitCalls = []
originalExit = process.exit testState.originalExit = process.exit
process.exit = vi.fn((code) => { process.exit = vi.fn((code) => {
exitCalls.push(code) testState.exitCalls.push(code)
}) })
}) })
afterEach(() => { afterEach(() => {
process.exit = originalExit process.exit = testState.originalExit
clearCleanup() clearCleanup()
}) })
@@ -46,7 +48,7 @@ describe('shutdown', () => {
await graceful() await graceful()
expect(calls).toEqual(['third', 'second', 'first']) expect(calls).toEqual(['third', 'second', 'first'])
expect(exitCalls).toEqual([0]) expect(testState.exitCalls).toEqual([0])
}) })
it('handles cleanup function errors gracefully', async () => { it('handles cleanup function errors gracefully', async () => {
@@ -59,26 +61,26 @@ describe('shutdown', () => {
await graceful() await graceful()
expect(exitCalls).toEqual([0]) expect(testState.exitCalls).toEqual([0])
}) })
it('exits with code 1 on error', async () => { it('exits with code 1 on error', async () => {
const error = new Error('Test error') const error = new Error('Test error')
await graceful(error) await graceful(error)
expect(exitCalls).toEqual([1]) expect(testState.exitCalls).toEqual([1])
}) })
it('prevents multiple shutdowns', async () => { it('prevents multiple shutdowns', async () => {
let callCount = 0 const callCount = { value: 0 }
registerCleanup(async () => { registerCleanup(async () => {
callCount++ callCount.value++
}) })
await graceful() await graceful()
await graceful() await graceful()
expect(callCount).toBe(1) expect(callCount.value).toBe(1)
}) })
it('handles cleanup error during graceful shutdown', async () => { it('handles cleanup error during graceful shutdown', async () => {
@@ -88,7 +90,7 @@ describe('shutdown', () => {
await graceful() await graceful()
expect(exitCalls).toEqual([0]) expect(testState.exitCalls).toEqual([0])
}) })
it('handles error in executeCleanup catch block', async () => { it('handles error in executeCleanup catch block', async () => {
@@ -98,20 +100,20 @@ describe('shutdown', () => {
await graceful() await graceful()
expect(exitCalls.length).toBeGreaterThan(0) expect(testState.exitCalls.length).toBeGreaterThan(0)
}) })
it('handles error with stack trace', async () => { it('handles error with stack trace', async () => {
const error = new Error('Test error') const error = new Error('Test error')
error.stack = 'Error: Test error\n at test.js:1:1' error.stack = 'Error: Test error\n at test.js:1:1'
await graceful(error) await graceful(error)
expect(exitCalls).toEqual([1]) expect(testState.exitCalls).toEqual([1])
}) })
it('handles error without stack trace', async () => { it('handles error without stack trace', async () => {
const error = { message: 'Test error' } const error = { message: 'Test error' }
await graceful(error) await graceful(error)
expect(exitCalls).toEqual([1]) expect(testState.exitCalls).toEqual([1])
}) })
it('handles timeout scenario', async () => { it('handles timeout scenario', async () => {
@@ -119,7 +121,7 @@ describe('shutdown', () => {
await new Promise(resolve => setTimeout(resolve, 40000)) await new Promise(resolve => setTimeout(resolve, 40000))
}) })
const timeout = setTimeout(() => { const timeout = setTimeout(() => {
expect(exitCalls.length).toBeGreaterThan(0) expect(testState.exitCalls.length).toBeGreaterThan(0)
}, 35000) }, 35000)
graceful() graceful()
await new Promise(resolve => setTimeout(resolve, 100)) await new Promise(resolve => setTimeout(resolve, 100))
@@ -130,7 +132,7 @@ describe('shutdown', () => {
registerCleanup(async () => {}) registerCleanup(async () => {})
await graceful() await graceful()
await graceful() // Second call should return early await graceful() // Second call should return early
expect(exitCalls.length).toBeGreaterThan(0) expect(testState.exitCalls.length).toBeGreaterThan(0)
}) })
it('covers initShutdownHandlers', () => { it('covers initShutdownHandlers', () => {
@@ -151,7 +153,7 @@ describe('shutdown', () => {
throw new Error('Force error in cleanup') throw new Error('Force error in cleanup')
}) })
await graceful() await graceful()
expect(exitCalls.length).toBeGreaterThan(0) expect(testState.exitCalls.length).toBeGreaterThan(0)
}) })
it('covers graceful catch block when executeCleanup throws', async () => { it('covers graceful catch block when executeCleanup throws', async () => {
@@ -176,7 +178,7 @@ describe('shutdown', () => {
await graceful() await graceful()
// Should exit successfully (code 0) because executeCleanup handles errors internally // Should exit successfully (code 0) because executeCleanup handles errors internally
expect(exitCalls).toContain(0) expect(testState.exitCalls).toContain(0)
expect(clearTimeoutCalls.length).toBeGreaterThan(0) expect(clearTimeoutCalls.length).toBeGreaterThan(0)
global.clearTimeout = originalClearTimeout global.clearTimeout = originalClearTimeout
}) })
@@ -213,7 +215,7 @@ describe('shutdown', () => {
await new Promise(resolve => setTimeout(resolve, 10)) await new Promise(resolve => setTimeout(resolve, 10))
expect(errorLogs.some(log => log.includes('Error in graceful shutdown'))).toBe(true) expect(errorLogs.some(log => log.includes('Error in graceful shutdown'))).toBe(true)
expect(exitCalls).toContain(1) expect(testState.exitCalls).toContain(1)
process.on = originalOn process.on = originalOn
process.exit = originalExit process.exit = originalExit

View File

@@ -19,13 +19,15 @@ vi.mock('../../server/utils/mediasoup.js', () => {
}) })
describe('webrtcSignaling', () => { describe('webrtcSignaling', () => {
let sessionId const testState = {
sessionId: null,
}
const userId = 'test-user' const userId = 'test-user'
beforeEach(async () => { beforeEach(async () => {
clearSessions() clearSessions()
const session = await createSession(userId, 'Test') const session = await createSession(userId, 'Test')
sessionId = session.id testState.sessionId = session.id
}) })
it('returns error when session not found', async () => { it('returns error when session not found', async () => {
@@ -34,35 +36,35 @@ describe('webrtcSignaling', () => {
}) })
it('returns Forbidden when userId does not match session', async () => { it('returns Forbidden when userId does not match session', async () => {
const res = await handleWebSocketMessage('other-user', sessionId, 'create-transport', {}) const res = await handleWebSocketMessage('other-user', testState.sessionId, 'create-transport', {})
expect(res).toEqual({ error: 'Forbidden' }) expect(res).toEqual({ error: 'Forbidden' })
}) })
it('returns error for unknown message type', async () => { it('returns error for unknown message type', async () => {
const res = await handleWebSocketMessage(userId, sessionId, 'unknown-type', {}) const res = await handleWebSocketMessage(userId, testState.sessionId, 'unknown-type', {})
expect(res).toEqual({ error: 'Unknown message type: unknown-type' }) expect(res).toEqual({ error: 'Unknown message type: unknown-type' })
}) })
it('returns transportId and dtlsParameters required for connect-transport', async () => { it('returns transportId and dtlsParameters required for connect-transport', async () => {
const res = await handleWebSocketMessage(userId, sessionId, 'connect-transport', {}) const res = await handleWebSocketMessage(userId, testState.sessionId, 'connect-transport', {})
expect(res?.error).toContain('transportId') expect(res?.error).toContain('transportId')
}) })
it('get-router-rtp-capabilities returns router RTP capabilities', async () => { it('get-router-rtp-capabilities returns router RTP capabilities', async () => {
const res = await handleWebSocketMessage(userId, sessionId, 'get-router-rtp-capabilities', {}) const res = await handleWebSocketMessage(userId, testState.sessionId, 'get-router-rtp-capabilities', {})
expect(res?.type).toBe('router-rtp-capabilities') expect(res?.type).toBe('router-rtp-capabilities')
expect(res?.data).toEqual({ codecs: [] }) expect(res?.data).toEqual({ codecs: [] })
}) })
it('create-transport returns transport params', async () => { it('create-transport returns transport params', async () => {
const res = await handleWebSocketMessage(userId, sessionId, 'create-transport', {}) const res = await handleWebSocketMessage(userId, testState.sessionId, 'create-transport', {})
expect(res?.type).toBe('transport-created') expect(res?.type).toBe('transport-created')
expect(res?.data).toBeDefined() expect(res?.data).toBeDefined()
}) })
it('connect-transport connects with valid params', async () => { it('connect-transport connects with valid params', async () => {
await handleWebSocketMessage(userId, sessionId, 'create-transport', {}) await handleWebSocketMessage(userId, testState.sessionId, 'create-transport', {})
const res = await handleWebSocketMessage(userId, sessionId, 'connect-transport', { const res = await handleWebSocketMessage(userId, testState.sessionId, 'connect-transport', {
transportId: 'mock-transport', transportId: 'mock-transport',
dtlsParameters: { role: 'client', fingerprints: [] }, dtlsParameters: { role: 'client', fingerprints: [] },
}) })
@@ -76,8 +78,8 @@ describe('webrtcSignaling', () => {
id: 'mock-transport', id: 'mock-transport',
connect: vi.fn().mockRejectedValue(new Error('Connection failed')), connect: vi.fn().mockRejectedValue(new Error('Connection failed')),
}) })
await handleWebSocketMessage(userId, sessionId, 'create-transport', {}) await handleWebSocketMessage(userId, testState.sessionId, 'create-transport', {})
const res = await handleWebSocketMessage(userId, sessionId, 'connect-transport', { const res = await handleWebSocketMessage(userId, testState.sessionId, 'connect-transport', {
transportId: 'mock-transport', transportId: 'mock-transport',
dtlsParameters: { role: 'client', fingerprints: [] }, dtlsParameters: { role: 'client', fingerprints: [] },
}) })