📄 vectorSearch.ts • 4179 bytes
/**
* CmdCode 向量记忆系统 - 向量搜索(简化版)
*
* 注:由于 Bun sqlite 暂不支持虚拟表,向量搜索简化为:
* - 存储 embedding blob 到数据库
* - 搜索时使用 FTS5 关键词匹配 + 内存向量计算
*/
import { createHash } from 'node:crypto'
import { getDb } from './database'
import { packEmbedding } from './utils'
import { getEmbedding, clearCache } from './embedding'
import { t } from '../i18n.js'
import type { MessageInfo } from './sessionStore'
// In-memory vector index keyed by message id (msgId -> embedding).
// Written by storeMessageEmbedding and pruned by deleteMessageVector.
// NOTE(review): nothing in this file reads entries back out of this map —
// searchVectors decodes vectors from the database instead. Confirm whether
// the index is still needed.
const memoryIndex = new Map<number, number[]>()
// Size threshold for memoryIndex: once the map holds this many entries,
// storeMessageEmbedding removes the first-inserted key before adding another.
const MAX_MEMORY_INDEX = 1000
/**
 * Compute the embedding for a message, persist it, and mirror it in the
 * in-memory index.
 *
 * The embedding is packed with `packEmbedding` and upserted into
 * `message_embeddings` together with a SHA-256 hex digest of `content`.
 *
 * @param msgId - Id of the message the embedding belongs to.
 * @param content - Message text to embed and hash.
 * @returns `true` once the row is written and the index updated; `false` if
 *          any step threw (the error is logged, not rethrown).
 */
export async function storeMessageEmbedding(msgId: number, content: string): Promise<boolean> {
  try {
    const embedding = await getEmbedding(content)
    const db = getDb()
    const vectorBlob = packEmbedding(embedding)
    const textHash = createHash('sha256').update(content).digest('hex')
    db.run(`
      INSERT OR REPLACE INTO message_embeddings (msg_id, embedding, text_hash, created_at)
      VALUES (?, ?, ?, datetime('now'))
    `, msgId, vectorBlob, textHash)

    // Mirror into the in-memory index. Removing any existing entry first means
    // that re-storing a known id never evicts an unrelated vector, and the
    // refreshed entry becomes the newest one in insertion order.
    memoryIndex.delete(msgId)
    if (memoryIndex.size >= MAX_MEMORY_INDEX) {
      // Map iterates in insertion order, so the first key is the oldest.
      const oldestKey = memoryIndex.keys().next().value
      if (oldestKey !== undefined) {
        memoryIndex.delete(oldestKey)
      }
    }
    memoryIndex.set(msgId, embedding)
    return true
  } catch (e) {
    console.error(t('error.save_vector'), e)
    return false
  }
}
/**
 * Cosine similarity between two vectors.
 *
 * A small epsilon (1e-10) is added to the denominator so that a zero vector
 * yields 0 instead of dividing by zero; as a consequence identical vectors
 * score marginally below 1.
 *
 * @param a - First vector.
 * @param b - Second vector.
 * @returns The similarity in roughly [-1, 1]; 0 when the vectors have
 *          different lengths, since they are not comparable.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  // Indexing past the end of the shorter vector would multiply by `undefined`
  // and produce NaN, which in turn breaks any sort over the results.
  if (a.length !== b.length) {
    return 0
  }
  let dot = 0, normA = 0, normB = 0
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i]
    normA += a[i] * a[i]
    normB += b[i] * b[i]
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-10)
}
/**
 * Decode a stored embedding BLOB into a plain number array.
 *
 * The bytes are interpreted as consecutive 32-bit floats in platform byte
 * order, as the previous inline decoding did.
 * NOTE(review): this assumes `packEmbedding` writes raw float32 values — confirm.
 *
 * @param blob - Value read from the `embedding` column.
 * @returns The decoded vector, or `null` when the value is not binary data,
 *          is empty, or its byte length is not a multiple of 4.
 */
function decodeEmbeddingBlob(blob: unknown): number[] | null {
  let bytes: Uint8Array
  if (ArrayBuffer.isView(blob)) {
    // Respect the view's own window: the backing ArrayBuffer may be larger
    // than the blob (e.g. a pooled Buffer) or start at a non-zero offset.
    bytes = new Uint8Array(blob.buffer, blob.byteOffset, blob.byteLength)
  } else if (blob instanceof ArrayBuffer) {
    bytes = new Uint8Array(blob)
  } else {
    return null
  }
  if (bytes.byteLength === 0 || bytes.byteLength % Float32Array.BYTES_PER_ELEMENT !== 0) {
    return null
  }
  // Copy into a fresh, zero-offset buffer so the Float32Array view is always
  // 4-byte aligned regardless of where the source bytes lived.
  const aligned = new Uint8Array(bytes)
  return Array.from(new Float32Array(aligned.buffer))
}

/**
 * Brute-force vector similarity search computed in memory.
 *
 * Embeds `query`, loads every message that has a stored embedding (optionally
 * restricted to one session), scores each with cosine similarity and returns
 * the closest matches. Rows whose embedding cannot be decoded or whose
 * dimensionality differs from the query embedding are skipped.
 *
 * @param query - Text to search for.
 * @param sessionId - When truthy, only messages of this session are considered.
 * @param limit - Maximum number of results; negative values are treated as 0.
 * @returns Messages with a `distance` field (1 - cosine similarity, smaller is
 *          closer), sorted ascending. Returns `[]` if anything throws; the
 *          error is logged.
 */
export async function searchVectors(query: string, sessionId?: string, limit = 20): Promise<(MessageInfo & { distance: number })[]> {
  try {
    const queryEmbedding = await getEmbedding(query)
    const db = getDb()

    // Fetch all candidate messages together with their stored vectors.
    const sessionFilter = sessionId ? 'm.session_id = ? AND ' : ''
    const sql = `
      SELECT m.*, e.embedding FROM messages m
      LEFT JOIN message_embeddings e ON m.id = e.msg_id
      WHERE ${sessionFilter}e.embedding IS NOT NULL
    `
    const params: string[] = sessionId ? [sessionId] : []
    // Row shape comes from the `messages` table, which is not declared in this
    // file, so the rows stay loosely typed.
    const rows = db.query(sql).all(...params) as any[]

    // Score every candidate against the query embedding.
    const results: (MessageInfo & { distance: number })[] = []
    for (const row of rows) {
      const embedding = decodeEmbeddingBlob(row.embedding)
      if (embedding === null || embedding.length !== queryEmbedding.length) {
        // Undecodable or differently-sized vectors would only contribute
        // NaN / meaningless distances; leave them out of the ranking.
        continue
      }
      const similarity = cosineSimilarity(queryEmbedding, embedding)
      results.push({
        id: row.id,
        session_id: row.session_id,
        role: row.role,
        content: row.content,
        created_at: row.created_at,
        distance: 1 - similarity // convert to a distance: smaller means more similar
      })
    }

    // Closest first, then keep the top N.
    results.sort((a, b) => a.distance - b.distance)
    return results.slice(0, Math.max(0, limit))
  } catch (e) {
    console.error(t('error.vector_search'), e)
    return []
  }
}
/**
 * Remove the stored embedding of a message.
 *
 * Deletes the row from `message_embeddings`, drops the matching entry from the
 * in-memory index and clears the embedding cache.
 *
 * @param msgId - Id of the message whose vector should be removed.
 * @returns `true` if the DELETE affected at least one row.
 */
export function deleteMessageVector(msgId: number): boolean {
  const { changes } = getDb().run('DELETE FROM message_embeddings WHERE msg_id = ?', msgId)

  memoryIndex.delete(msgId)

  // P3 #2.7: clear the embedding cache after deleting a vector so that
  // re-writing the same content does not hit a stale cached entry.
  clearCache()

  return changes > 0
}
/**
 * Count the rows currently stored in `message_embeddings`.
 *
 * @returns The row count, or 0 when the query yields no row.
 */
export function getVectorCount(): number {
  const row = getDb()
    .query('SELECT COUNT(*) as cnt FROM message_embeddings')
    .get() as { cnt: number } | null | undefined
  if (!row) {
    return 0
  }
  return row.cnt || 0
}
/**
 * Check whether a message has a stored embedding.
 *
 * @param msgId - Id of the message to look up.
 * @returns `true` if a row for `msgId` exists in `message_embeddings`.
 */
export function hasVector(msgId: number): boolean {
  const match = getDb()
    .query('SELECT 1 FROM message_embeddings WHERE msg_id = ?')
    .get(msgId)
  return Boolean(match)
}