|
4 | 4 | "context" |
5 | 5 | "errors" |
6 | 6 | "fmt" |
7 | | - "strings" |
8 | 7 | "time" |
9 | 8 |
|
10 | 9 | "github.com/QuantumNous/new-api/common" |
@@ -309,9 +308,15 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName |
309 | 308 | tx = LOG_DB.Where("logs.type = ?", logType) |
310 | 309 | } |
311 | 310 |
|
312 | | - tx = applyLogContainsFilter(tx, "logs.model_name", modelName) |
313 | | - tx = applyLogContainsFilter(tx, "logs.username", username) |
314 | | - tx = applyLogContainsFilter(tx, "logs.token_name", tokenName) |
| 311 | + if modelName != "" { |
| 312 | + tx = tx.Where("logs.model_name like ?", modelName) |
| 313 | + } |
| 314 | + if username != "" { |
| 315 | + tx = tx.Where("logs.username = ?", username) |
| 316 | + } |
| 317 | + if tokenName != "" { |
| 318 | + tx = tx.Where("logs.token_name = ?", tokenName) |
| 319 | + } |
315 | 320 | if requestId != "" { |
316 | 321 | tx = tx.Where("logs.request_id = ?", requestId) |
317 | 322 | } |
@@ -392,8 +397,16 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int |
392 | 397 | tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType) |
393 | 398 | } |
394 | 399 |
|
395 | | - tx = applyLogContainsFilter(tx, "logs.model_name", modelName) |
396 | | - tx = applyLogContainsFilter(tx, "logs.token_name", tokenName) |
| 400 | + if modelName != "" { |
| 401 | + modelNamePattern, err := sanitizeLikePattern(modelName) |
| 402 | + if err != nil { |
| 403 | + return nil, 0, err |
| 404 | + } |
| 405 | + tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern) |
| 406 | + } |
| 407 | + if tokenName != "" { |
| 408 | + tx = tx.Where("logs.token_name = ?", tokenName) |
| 409 | + } |
397 | 410 | if requestId != "" { |
398 | 411 | tx = tx.Where("logs.request_id = ?", requestId) |
399 | 412 | } |
@@ -430,42 +443,34 @@ type Stat struct { |
430 | 443 | Tpm int `json:"tpm"` |
431 | 444 | } |
432 | 445 |
|
433 | | -func logContainsPattern(input string) (string, bool) { |
434 | | - input = strings.TrimSpace(input) |
435 | | - if input == "" { |
436 | | - return "", false |
437 | | - } |
438 | | - |
439 | | - replacer := strings.NewReplacer("!", "!!", "%", "!%", "_", "!_") |
440 | | - return "%" + replacer.Replace(input) + "%", true |
441 | | -} |
442 | | - |
443 | | -func applyLogContainsFilter(tx *gorm.DB, column string, value string) *gorm.DB { |
444 | | - pattern, ok := logContainsPattern(value) |
445 | | - if !ok { |
446 | | - return tx |
447 | | - } |
448 | | - return tx.Where(column+" LIKE ? ESCAPE '!'", pattern) |
449 | | -} |
450 | | - |
451 | 446 | func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) { |
452 | 447 | tx := LOG_DB.Table("logs").Select("sum(quota) quota") |
453 | 448 |
|
454 | 449 | // 为rpm和tpm创建单独的查询 |
455 | 450 | rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") |
456 | 451 |
|
457 | | - tx = applyLogContainsFilter(tx, "username", username) |
458 | | - rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "username", username) |
459 | | - tx = applyLogContainsFilter(tx, "token_name", tokenName) |
460 | | - rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "token_name", tokenName) |
| 452 | + if username != "" { |
| 453 | + tx = tx.Where("username = ?", username) |
| 454 | + rpmTpmQuery = rpmTpmQuery.Where("username = ?", username) |
| 455 | + } |
| 456 | + if tokenName != "" { |
| 457 | + tx = tx.Where("token_name = ?", tokenName) |
| 458 | + rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName) |
| 459 | + } |
461 | 460 | if startTimestamp != 0 { |
462 | 461 | tx = tx.Where("created_at >= ?", startTimestamp) |
463 | 462 | } |
464 | 463 | if endTimestamp != 0 { |
465 | 464 | tx = tx.Where("created_at <= ?", endTimestamp) |
466 | 465 | } |
467 | | - tx = applyLogContainsFilter(tx, "model_name", modelName) |
468 | | - rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "model_name", modelName) |
| 466 | + if modelName != "" { |
| 467 | + modelNamePattern, err := sanitizeLikePattern(modelName) |
| 468 | + if err != nil { |
| 469 | + return stat, err |
| 470 | + } |
| 471 | + tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern) |
| 472 | + rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern) |
| 473 | + } |
469 | 474 | if channel != 0 { |
470 | 475 | tx = tx.Where("channel_id = ?", channel) |
471 | 476 | rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) |
|
0 commit comments