📄 sessionStore.ts  •  4414 bytes
/**
 * CmdCode 向量记忆系统 - 会话与消息管理 + FTS5搜索
 */
import { getDb } from './database'
import { sanitizeFTS5Query, generateId } from './utils'
import { t } from '../i18n.js'

/**
 * A chat session as exposed by this module.
 *
 * Timestamps are plain strings. Rows read back from the `sessions` table carry
 * whatever format that table stores. `createSession` fills them with ISO-8601
 * values from `Date#toISOString()`.
 */
export interface SessionInfo {
  /** Session identifier; `createSession` generates it with the `sess` prefix. */
  id: string
  /** Human-readable session title. */
  title: string
  /** When the session was created. */
  created_at: string
  /** When the session was last updated; `listSessions` sorts on this, newest first. */
  updated_at: string
  /** Number of messages in the session; filled in by `listSessions` in this module. */
  message_count?: number
}

/**
 * A single stored chat message.
 *
 * `id` is the SQLite row id of the `messages` row. `addMessage` reads it back
 * via `last_insert_rowid()`, and `searchFTS` joins it against `messages_fts.rowid`.
 */
export interface MessageInfo {
  /** Row id of the message in the `messages` table. */
  id: number
  /** Id of the session that owns this message. */
  session_id: string
  /** Speaker role, stored as a free-form string. */
  role: string
  /** Message text. */
  content: string
  /**
   * Creation timestamp. `addMessage` returns a client-side ISO-8601 value;
   * rows read from the database carry the stored format.
   */
  created_at: string
}

/**
 * Create a new chat session and persist it to the `sessions` table.
 *
 * @param title - Optional title. A missing or empty title falls back to the
 *   default title '新会话' ("New session").
 * @returns The new session. `created_at`/`updated_at` come from a single
 *   client-side clock reading (ISO-8601). They may not match the format the
 *   database stores for this row; the column defaults are not visible here.
 */
export function createSession(title?: string): SessionInfo {
  const db = getDb()
  const id = generateId('sess')
  // `||` (not `??`) on purpose: an empty string also falls back to the default title.
  const sessionTitle = title || '新会话'

  db.run('INSERT INTO sessions (id, title) VALUES (?, ?)', id, sessionTitle)

  // Read the clock once so a brand-new session always has created_at === updated_at.
  const now = new Date().toISOString()
  return { id, title: sessionTitle, created_at: now, updated_at: now }
}

/**
 * Look up a single session by id.
 *
 * @param sessionId - Id of the session to fetch.
 * @returns The raw `sessions` row, or `null` when no row has that id.
 */
export function getSession(sessionId: string): SessionInfo | null {
  const row = getDb()
    .query('SELECT * FROM sessions WHERE id = ?')
    .get(sessionId) as SessionInfo | null | undefined
  return row ? row : null
}

/**
 * List sessions, most recently updated first, each annotated with its message count.
 *
 * @param limit - Maximum number of sessions to return (default 50).
 * @returns Session rows plus `message_count`. The LEFT JOIN with `COUNT(m.id)`
 *   makes sessions without messages report 0.
 */
export function listSessions(limit = 50): SessionInfo[] {
  const sql = `
    SELECT s.*, COUNT(m.id) as message_count 
    FROM sessions s 
    LEFT JOIN messages m ON m.session_id = s.id 
    GROUP BY s.id 
    ORDER BY s.updated_at DESC 
    LIMIT ?
  `
  const rows = getDb().query(sql).all(limit) as SessionInfo[]
  return rows ? rows : []
}

/**
 * Delete a session row by id.
 *
 * Only the `sessions` row is deleted here. NOTE(review): whether the session's
 * messages are removed too depends on the schema defined outside this file
 * (e.g. an ON DELETE CASCADE foreign key). Confirm this.
 *
 * @param sessionId - Id of the session to delete.
 * @returns `true` if a row was deleted, `false` if no session had that id.
 */
export function deleteSession(sessionId: string): boolean {
  const { changes } = getDb().run('DELETE FROM sessions WHERE id = ?', sessionId)
  return changes > 0
}

/**
 * Append a message to a session.
 *
 * NOTE(review): this does not update `sessions.updated_at`, which `listSessions`
 * sorts on. Confirm whether a trigger elsewhere in the schema maintains it.
 *
 * @param sessionId - Id of the owning session.
 * @param role - Speaker role, stored verbatim.
 * @param content - Message text.
 * @returns The stored message. `id` is the new row id, or 0 if it could not be
 *   read back. `created_at` is a client-side ISO-8601 timestamp, not the value
 *   stored in the row.
 */
export function addMessage(sessionId: string, role: string, content: string): MessageInfo {
  const db = getDb()

  db.run('INSERT INTO messages (session_id, role, content) VALUES (?, ?, ?)', sessionId, role, content)

  // last_insert_rowid() is scoped to this connection, so it reports the row inserted just above.
  const lastId = db.query('SELECT last_insert_rowid() as id').get() as { id?: number } | null | undefined

  return {
    id: lastId?.id ?? 0,
    session_id: sessionId,
    role,
    content,
    created_at: new Date().toISOString()
  }
}

/**
 * Fetch a session's messages in chronological order.
 *
 * Ties on `created_at` are broken by `id`, the row id, which normally follows
 * insertion order. This keeps the order deterministic when several messages
 * share a timestamp, e.g. when the stored timestamp has only second resolution.
 * Without the tiebreaker SQLite may return tied rows in any order.
 *
 * @param sessionId - Id of the session whose messages to fetch.
 * @param limit - Maximum number of messages to return (default 100). The
 *   oldest messages are the ones returned.
 * @returns The session's messages, oldest first.
 */
export function getSessionMessages(sessionId: string, limit = 100): MessageInfo[] {
  const db = getDb()
  const results = db.query(`
    SELECT * FROM messages
    WHERE session_id = ?
    ORDER BY created_at ASC, id ASC
    LIMIT ?
  `).all(sessionId, limit) as MessageInfo[]
  return results ?? []
}

/**
 * Full-text search over messages via the `messages_fts` FTS5 table.
 *
 * The raw query goes through `sanitizeFTS5Query` before it is bound to `MATCH`.
 * Hits are ordered by FTS5 `rank`.
 *
 * @param query - User-supplied search text.
 * @param sessionId - When given (non-empty), restrict hits to this session.
 * @param limit - Maximum number of messages to return (default 20).
 * @returns Matching messages. Returns `[]` when the sanitized query has no
 *   searchable text or when the search fails; failures are logged, not thrown.
 */
export function searchFTS(query: string, sessionId?: string, limit = 20): MessageInfo[] {
  const db = getDb()
  const ftsQuery = sanitizeFTS5Query(query)

  // A blank MATCH expression cannot match anything, so skip the round-trip
  // (and any FTS5 parse error it might log).
  if (!ftsQuery?.trim()) return []

  // The session-scoped and global searches differ only by this optional filter.
  const sessionFilter = sessionId ? ' AND m.session_id = ?' : ''
  const params: Array<string | number> = sessionId
    ? [ftsQuery, sessionId, limit]
    : [ftsQuery, limit]
  const sql = `
    SELECT m.* FROM messages m
    JOIN messages_fts fts ON m.id = fts.rowid
    WHERE messages_fts MATCH ?${sessionFilter}
    ORDER BY rank
    LIMIT ?
  `

  try {
    const results = db.query(sql).all(...params) as MessageInfo[]
    return results ?? []
  } catch (e) {
    // Best-effort search: malformed FTS syntax or DB errors degrade to "no results".
    console.error(t('error.fts_search'), e)
    return []
  }
}

/**
 * Record one failed embedding attempt for a message.
 *
 * The first failure inserts a row with `fail_count = 1`. Later failures
 * increment `fail_count`, overwrite `last_error`, and refresh `updated_at`.
 * `getPendingEmbeddings` consults `fail_count` to decide when to stop retrying.
 *
 * NOTE(review): the existence check and the write are two separate statements,
 * not one atomic upsert.
 *
 * @param msgId - Row id of the message whose embedding failed.
 * @param error - Error text to store as `last_error`.
 */
export function recordEmbeddingFailure(msgId: number, error: string): void {
  const db = getDb()
  const priorFailure = db.query('SELECT * FROM embedding_failures WHERE msg_id = ?').get(msgId)

  if (!priorFailure) {
    db.run(`
      INSERT INTO embedding_failures (msg_id, fail_count, last_error)
      VALUES (?, 1, ?)
    `, msgId, error)
  } else {
    db.run(`
      UPDATE embedding_failures SET 
        fail_count = fail_count + 1,
        last_error = ?,
        updated_at = datetime('now')
      WHERE msg_id = ?
    `, error, msgId)
  }
}

/**
 * Select messages that still need an embedding (the backfill queue).
 *
 * A message qualifies when it has no row in `message_embeddings` and, per
 * `embedding_failures`, has failed fewer than `maxFailures` times.
 *
 * @param limit - Maximum number of messages to return (default 50).
 * @param maxFailures - Retry cap. Messages with this many recorded failures or
 *   more are skipped (default 3).
 * @returns Pending messages, oldest first. Ties on `created_at` are broken by
 *   id so each backfill batch is deterministic.
 */
export function getPendingEmbeddings(limit = 50, maxFailures = 3): MessageInfo[] {
  const db = getDb()
  const results = db.query(`
    SELECT m.* FROM messages m
    LEFT JOIN message_embeddings e ON m.id = e.msg_id
    LEFT JOIN embedding_failures f ON m.id = f.msg_id
    WHERE e.msg_id IS NULL
      AND (f.fail_count IS NULL OR f.fail_count < ?)
    ORDER BY m.created_at ASC, m.id ASC
    LIMIT ?
  `).all(maxFailures, limit) as MessageInfo[]
  return results ?? []
}