Skip to main content

LangChain Go SQLDatabase 工具使用指南

概述

github.com/tmc/langchaingo/tools/sqldatabase 是LangChain Go版本中用于操作SQL数据库的工具包。它提供了一个统一的接口来与各种数据库(MySQL、PostgreSQL、SQLite等)进行交互。

安装和导入

// go.mod
require (
github.com/tmc/langchaingo v0.1.12
github.com/go-sql-driver/mysql v1.7.1
)

// 导入
import (
"github.com/tmc/langchaingo/tools/sqldatabase"
_ "github.com/tmc/langchaingo/tools/sqldatabase/mysql" // MySQL驱动
_ "github.com/go-sql-driver/mysql" // MySQL数据库驱动
)

基本用法

1. 创建数据库连接

func main() {
// MySQL连接字符串
dsn := "username:password@tcp(localhost:3306)/database_name"

// 创建SQL数据库工具
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, nil)
if err != nil {
log.Fatal("Failed to create SQL database tool:", err)
}
defer sqlTool.Close()

ctx := context.Background()
// ... 使用sqlTool
}

2. 主要方法

获取表信息

// 获取所有表的结构信息
schema, err := sqlTool.TableInfo(ctx, nil)

// 获取特定表的信息
tableInfo, err := sqlTool.TableInfo(ctx, []string{"users", "orders"})

执行查询

// 执行SQL查询
result, err := sqlTool.Query(ctx, "SELECT * FROM users LIMIT 10")
if err != nil {
log.Printf("Query failed: %v", err)
} else {
fmt.Printf("Result:\n%s\n", result)
}

获取表名列表

// 获取所有表名
tableNames := sqlTool.TableNames()
fmt.Printf("Tables: %v\n", tableNames)

3. 高级配置

忽略特定表

// 创建时忽略某些表
ignoreTables := map[string]struct{}{
"migration_logs": {},
"system_logs": {},
}

sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, ignoreTables)

自定义采样行数

// SQLDatabase结构体包含SampleRowsNumber字段来控制示例行数
// 默认值是3行,可以通过直接访问修改
sqlTool.SampleRowsNumber = 5 // 显示5行示例数据

与LLM集成

1. 使用SQL查询链

import (
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms/openai"
)

func useSQLChain() {
// 创建LLM
llm, err := openai.New()
if err != nil {
log.Fatal(err)
}

// 创建SQL工具
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, nil)
if err != nil {
log.Fatal(err)
}
defer sqlTool.Close()

// 创建SQL查询链
sqlChain := chains.NewSQLDatabaseChain(llm, sqlTool)

// 自然语言查询
result, err := sqlChain.Call(ctx, map[string]any{
"query": "How many users are in the database?",
})
}

2. 自定义提示模板

func customPromptExample(sqlTool *sqldatabase.SQLDatabase, llm llms.LLM) {
ctx := context.Background()

// 获取数据库schema
schema, err := sqlTool.TableInfo(ctx, nil)
if err != nil {
return
}

// 自定义提示
prompt := fmt.Sprintf(`
Given the following database schema:
%s

Generate a SQL query to answer: %s

Return only the SQL query.
`, schema, "Find all active users")

// 使用LLM生成SQL
response, err := llm.Call(ctx, prompt)
if err != nil {
return
}

// 执行生成的SQL
result, err := sqlTool.Query(ctx, response)
fmt.Printf("Result: %s\n", result)
}

实用示例

1. 数据库统计查询

// 用户统计
userStats, _ := sqlTool.Query(ctx,
"SELECT COUNT(*) as total_users, MAX(created_at) as latest_user FROM users")

// 状态分布
statusDist, _ := sqlTool.Query(ctx,
"SELECT status, COUNT(*) as count FROM tasks GROUP BY status")

2. 复杂联接查询

complexQuery := `
SELECT
c.title as channel_title,
COUNT(v.id) as video_count,
c.status
FROM tb_channel c
LEFT JOIN tb_video v ON c.channel_id = v.channel_id
GROUP BY c.channel_id, c.title, c.status
ORDER BY video_count DESC
LIMIT 10
`
result, _ := sqlTool.Query(ctx, complexQuery)

3. 事务处理

queries := []string{
"START TRANSACTION",
"UPDATE users SET last_login = NOW() WHERE id = 1",
"INSERT INTO user_activity (user_id, activity) VALUES (1, 'login')",
"COMMIT",
}

for _, query := range queries {
_, err := sqlTool.Query(ctx, query)
if err != nil {
sqlTool.Query(ctx, "ROLLBACK")
break
}
}

支持的数据库

  • MySQL: github.com/tmc/langchaingo/tools/sqldatabase/mysql
  • PostgreSQL: github.com/tmc/langchaingo/tools/sqldatabase/postgresql
  • SQLite: github.com/tmc/langchaingo/tools/sqldatabase/sqlite3

环境变量配置

# .env 文件
MYSQL_DSN="username:password@tcp(localhost:3306)/database_name"
OPENAI_API_KEY="your_openai_api_key"
// 在代码中使用
dsn := os.Getenv("MYSQL_DSN")
if dsn == "" {
dsn = "default_connection_string"
}

错误处理

// 数据库连接错误
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, nil)
if err != nil {
log.Fatal("Database connection failed:", err)
}

// SQL执行错误
result, err := sqlTool.Query(ctx, sqlQuery)
if err != nil {
log.Printf("SQL execution failed: %v", err)
return
}

最佳实践

  1. 连接管理: 总是使用 defer sqlTool.Close() 确保连接关闭
  2. 错误处理: 对所有数据库操作进行适当的错误处理
  3. SQL注入防护: 使用参数化查询(虽然当前API不直接支持,需要手动转义)
  4. 性能优化: 对大表使用LIMIT限制返回行数
  5. Schema缓存: 考虑缓存schema信息以提高性能

与其他工具集成

1. 与Prometheus监控集成

// 监控查询执行时间
start := time.Now()
result, err := sqlTool.Query(ctx, query)
duration := time.Since(start)
// 记录指标...

2. 与日志系统集成

// 结构化日志记录
log.WithFields(log.Fields{
"query": query,
"duration": duration,
"error": err,
}).Info("SQL query executed")

这个工具包提供了强大而灵活的数据库操作能力,特别适合与AI/LLM应用集成,实现自然语言到SQL的转换。