Mirror of https://github.com/voson-wang/toon.git, synced 2026-01-29 15:24:10 +08:00
test(benchmark): overhaul generation
@@ -5,16 +5,83 @@ import process from 'node:process'
 import * as prompts from '@clack/prompts'
 import PQueue from 'p-queue'
 import { BENCHMARKS_DIR, DEFAULT_CONCURRENCY, DRY_RUN, DRY_RUN_LIMITS, MODEL_RPM_LIMITS, ROOT_DIR } from '../src/constants'
-import { datasets } from '../src/datasets'
+import { ACCURACY_DATASETS } from '../src/datasets'
 import { evaluateQuestion, models } from '../src/evaluate'
-import { formatters } from '../src/formatters'
+import { formatters, supportsCSV } from '../src/formatters'
 import { generateQuestions } from '../src/questions'
 import { calculateFormatResults, calculateTokenCounts, generateAccuracyReport } from '../src/report'
 import { getAllModelResults, hasModelResults, saveModelResults } from '../src/storage'
 import { ensureDir } from '../src/utils'
 
+// Constants
+const PROGRESS_UPDATE_INTERVAL = 10
+const RATE_LIMIT_INTERVAL_MS = 60_000
+
 prompts.intro('Retrieval Accuracy Benchmark')
 
+/**
+ * Generate evaluation tasks for a model
+ */
+function generateEvaluationTasks(questions: Question[]): { question: Question, formatName: string }[] {
+  const tasks: { question: Question, formatName: string }[] = []
+
+  for (const question of questions) {
+    for (const [formatName] of Object.entries(formatters)) {
+      // Skip CSV for datasets that don't support it
+      const dataset = ACCURACY_DATASETS.find(d => d.name === question.dataset)
+      if (formatName === 'csv' && dataset && !supportsCSV(dataset))
+        continue
+
+      tasks.push({ question, formatName })
+    }
+  }
+
+  return tasks
+}
+
+/**
+ * Check which models already have saved results
+ */
+async function checkExistingResults(activeModels: typeof models) {
+  const existingModelResults: Record<string, boolean> = {}
+
+  for (const model of activeModels) {
+    const existingResult = await hasModelResults(model.modelId)
+    if (existingResult)
+      existingModelResults[model.modelId] = existingResult
+  }
+
+  return existingModelResults
+}
+
+/**
+ * Create a progress updater function
+ */
+function createProgressUpdater(spinner: ReturnType<typeof prompts.spinner>, total: number) {
+  let completed = 0
+
+  return () => {
+    completed++
+    if (completed % PROGRESS_UPDATE_INTERVAL === 0 || completed === total) {
+      const percent = ((completed / total) * 100).toFixed(1)
+      spinner.message(`Progress: ${completed}/${total} (${percent}%)`)
+    }
+  }
+}
+
+/**
+ * Create a rate-limited queue for model evaluation
+ */
+function createEvaluationQueue(modelId: string) {
+  const rpmLimit = MODEL_RPM_LIMITS[modelId]
+
+  return new PQueue({
+    concurrency: DEFAULT_CONCURRENCY,
+    intervalCap: rpmLimit ?? Infinity,
+    interval: rpmLimit ? RATE_LIMIT_INTERVAL_MS : 0,
+  })
+}
+
 // Prompt user to select which models to benchmark
 const modelChoices = models.map(({ modelId }) => ({
   value: modelId,
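Note: the new supportsCSV guard is imported above but not shown in this diff. Presumably it rejects datasets whose records are not flat, uniform rows, since CSV cannot represent nested values. A minimal sketch under that assumption (the Dataset shape here is also assumed, not the repository's real type):

    // Assumed shape for illustration; the real type lives in src/datasets.
    interface Dataset {
      name: string
      data: Record<string, unknown>[]
    }

    // CSV needs flat, uniform rows: every record shares one key set,
    // and no value is a nested object or array.
    function supportsCSV(dataset: Dataset): boolean {
      const [first] = dataset.data
      if (!first)
        return false
      const keys = Object.keys(first).sort().join(',')
      return dataset.data.every(row =>
        Object.keys(row).sort().join(',') === keys
        && Object.values(row).every(v => v === null || typeof v !== 'object'),
      )
    }

With a guard like this, generateEvaluationTasks simply drops the csv format for nested datasets instead of feeding models malformed tables.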
@@ -37,15 +104,10 @@ const activeModels = models.filter(m => selectedModels.includes(m.modelId))
 prompts.log.info(`Selected ${activeModels.length} model(s): ${activeModels.map(m => m.modelId).join(', ')}`)
 
 // Check which models already have results
-const existingModelResults: Record<string, boolean> = {}
-for (const model of activeModels) {
-  const existingResult = await hasModelResults(model.modelId)
-  if (existingResult)
-    existingModelResults[model.modelId] = existingResult
-}
+const existingModelResults = await checkExistingResults(activeModels)
 
 if (Object.keys(existingModelResults).length > 0) {
-  prompts.log.info(`Found existing results for ${Object.values(existingModelResults).length} model(s)`)
+  prompts.log.info(`Found existing results for ${Object.keys(existingModelResults).length} model(s)`)
 }
 
 if (DRY_RUN) {
@@ -75,31 +137,22 @@ for (const model of activeModels) {
   prompts.log.step(`Running benchmark for ${modelId}`)
 
   // Generate evaluation tasks for this model
-  const tasks: { question: Question, formatName: string }[] = []
-  for (const question of questions) {
-    for (const [formatName] of Object.entries(formatters)) {
-      tasks.push({ question, formatName })
-    }
-  }
+  const tasks = generateEvaluationTasks(questions)
 
   const total = tasks.length
   const rpmLimit = MODEL_RPM_LIMITS[modelId]
-  const queue = new PQueue({
-    concurrency: DEFAULT_CONCURRENCY,
-    intervalCap: rpmLimit ?? Infinity,
-    interval: rpmLimit ? 60_000 : 0,
-  })
+  const queue = createEvaluationQueue(modelId)
 
   const evalSpinner = prompts.spinner()
   evalSpinner.start(`Running ${total} evaluations (concurrency: ${DEFAULT_CONCURRENCY}, RPM limit: ${rpmLimit ?? 'unlimited'})`)
 
-  let completed = 0
+  const updateProgress = createProgressUpdater(evalSpinner, total)
 
   // Queue all tasks
   const modelResultPromises = tasks.map(task =>
     queue.add(async () => {
       // Format data on-demand
-      const dataset = datasets.find(d => d.name === task.question.dataset)!
+      const dataset = ACCURACY_DATASETS.find(d => d.name === task.question.dataset)!
       const formatter = formatters[task.formatName]!
       const formattedData = formatter(dataset.data)
 
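Note: p-queue enforces the RPM cap by allowing at most intervalCap task starts per interval window, on top of the concurrency ceiling on in-flight tasks. A standalone sketch of the same throttling pattern (the 30 RPM limit and the callModel stub are illustrative, not values from this repository):

    import PQueue from 'p-queue'

    // Illustrative stub; the benchmark's real unit of work is evaluateQuestion().
    async function callModel(i: number): Promise<string> {
      return `answer ${i}`
    }

    const rpmLimit = 30 // hypothetical; real limits come from MODEL_RPM_LIMITS

    const queue = new PQueue({
      concurrency: 5, // at most 5 tasks in flight at once
      intervalCap: rpmLimit, // at most 30 task starts...
      interval: 60_000, // ...per 60-second window
    })

    // Tasks beyond the cap wait for the next window before starting, so 100
    // tasks at 30 RPM take over three minutes regardless of concurrency.
    const results = await Promise.all(
      Array.from({ length: 100 }, (_, i) => queue.add(() => callModel(i))),
    )
    console.log(`${results.length} evaluations done`)

Passing intervalCap: Infinity with interval: 0, as createEvaluationQueue does for models without a configured limit, disables the interval throttling entirely.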
@@ -111,11 +164,7 @@ for (const model of activeModels) {
       })
 
       // Progress update after task completes
-      completed++
-      if (completed % 10 === 0 || completed === total) {
-        const percent = ((completed / total) * 100).toFixed(1)
-        evalSpinner.message(`Progress: ${completed}/${total} (${percent}%)`)
-      }
+      updateProgress()
 
       return result
     }),
@@ -154,5 +203,5 @@ await ensureDir(resultsDir)
 const outputFilePath = path.join(resultsDir, 'retrieval-accuracy.md')
 await fsp.writeFile(outputFilePath, accuracyReport)
 
-prompts.log.info(`Report saved to: \`${path.relative(ROOT_DIR, outputFilePath)}\``)
 reportSpinner.stop('Report generation complete!')
+prompts.log.info(`Report saved to: \`${path.relative(ROOT_DIR, outputFilePath)}\``)