LangChain Go SQLDatabase 工具使用指南
概述
github.com/tmc/langchaingo/tools/sqldatabase 是LangChain Go版本中用于操作SQL数据库的工具包。它提供了一个统一的接口来与各种数据库(MySQL、PostgreSQL、SQLite等)进行交互。
安装和导入
// go.mod
require (
github.com/tmc/langchaingo v0.1.12
github.com/go-sql-driver/mysql v1.7.1
)
// 导入
import (
"github.com/tmc/langchaingo/tools/sqldatabase"
_ "github.com/tmc/langchaingo/tools/sqldatabase/mysql" // MySQL驱动
_ "github.com/go-sql-driver/mysql" // MySQL数据库驱动
)
基本用法
1. 创建数据库连接
func main() {
// MySQL连接字符串
dsn := "username:password@tcp(localhost:3306)/database_name"
// 创建SQL数据库工具
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, nil)
if err != nil {
log.Fatal("Failed to create SQL database tool:", err)
}
defer sqlTool.Close()
ctx := context.Background()
// ... 使用sqlTool
}
2. 主要方法
获取表信息
// 获取所有表的结构信息
schema, err := sqlTool.TableInfo(ctx, nil)
// 获取特定表的信息
tableInfo, err := sqlTool.TableInfo(ctx, []string{"users", "orders"})
执行查询
// 执行SQL查询
result, err := sqlTool.Query(ctx, "SELECT * FROM users LIMIT 10")
if err != nil {
log.Printf("Query failed: %v", err)
} else {
fmt.Printf("Result:\n%s\n", result)
}
获取表名列表
// 获取所有表名
tableNames := sqlTool.TableNames()
fmt.Printf("Tables: %v\n", tableNames)
3. 高级配置
忽略特定表
// 创建时忽略某些表
ignoreTables := map[string]struct{}{
"migration_logs": {},
"system_logs": {},
}
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, ignoreTables)
自定义采样行数
// SQLDatabase结构体包含SampleRowsNumber字段来控制示例行数
// 默认值是3行,可以通过直接访问修改
sqlTool.SampleRowsNumber = 5 // 显示5行示例数据
与LLM集成
1. 使用SQL查询链
import (
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms/openai"
)
func useSQLChain() {
// 创建LLM
llm, err := openai.New()
if err != nil {
log.Fatal(err)
}
// 创建SQL工具
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, nil)
if err != nil {
log.Fatal(err)
}
defer sqlTool.Close()
// 创建SQL查询链
sqlChain := chains.NewSQLDatabaseChain(llm, sqlTool)
// 自然语言查询
result, err := sqlChain.Call(ctx, map[string]any{
"query": "How many users are in the database?",
})
}
2. 自定义提示模板
func customPromptExample(sqlTool *sqldatabase.SQLDatabase, llm llms.LLM) {
ctx := context.Background()
// 获取数据库schema
schema, err := sqlTool.TableInfo(ctx, nil)
if err != nil {
return
}
// 自定义提示
prompt := fmt.Sprintf(`
Given the following database schema:
%s
Generate a SQL query to answer: %s
Return only the SQL query.
`, schema, "Find all active users")
// 使用LLM生成SQL
response, err := llm.Call(ctx, prompt)
if err != nil {
return
}
// 执行生成的SQL
result, err := sqlTool.Query(ctx, response)
fmt.Printf("Result: %s\n", result)
}
实用示例
1. 数据库统计查询
// 用户统计
userStats, _ := sqlTool.Query(ctx,
"SELECT COUNT(*) as total_users, MAX(created_at) as latest_user FROM users")
// 状态分布
statusDist, _ := sqlTool.Query(ctx,
"SELECT status, COUNT(*) as count FROM tasks GROUP BY status")
2. 复杂联接查询
complexQuery := `
SELECT
c.title as channel_title,
COUNT(v.id) as video_count,
c.status
FROM tb_channel c
LEFT JOIN tb_video v ON c.channel_id = v.channel_id
GROUP BY c.channel_id, c.title, c.status
ORDER BY video_count DESC
LIMIT 10
`
result, _ := sqlTool.Query(ctx, complexQuery)
3. 事务处理
queries := []string{
"START TRANSACTION",
"UPDATE users SET last_login = NOW() WHERE id = 1",
"INSERT INTO user_activity (user_id, activity) VALUES (1, 'login')",
"COMMIT",
}
for _, query := range queries {
_, err := sqlTool.Query(ctx, query)
if err != nil {
sqlTool.Query(ctx, "ROLLBACK")
break
}
}
支持的数据库
- MySQL:
github.com/tmc/langchaingo/tools/sqldatabase/mysql - PostgreSQL:
github.com/tmc/langchaingo/tools/sqldatabase/postgresql - SQLite:
github.com/tmc/langchaingo/tools/sqldatabase/sqlite3
环境变量配置
# .env 文件
MYSQL_DSN="username:password@tcp(localhost:3306)/database_name"
OPENAI_API_KEY="your_openai_api_key"
// 在代码中使用
dsn := os.Getenv("MYSQL_DSN")
if dsn == "" {
dsn = "default_connection_string"
}
错误处理
// 数据库连接错误
sqlTool, err := sqldatabase.NewSQLDatabaseWithDSN("mysql", dsn, nil)
if err != nil {
log.Fatal("Database connection failed:", err)
}
// SQL执行错误
result, err := sqlTool.Query(ctx, sqlQuery)
if err != nil {
log.Printf("SQL execution failed: %v", err)
return
}
最佳实践
- 连接管理: 总是使用
defer sqlTool.Close()确保连接关闭 - 错误处理: 对所有数据库操作进行适当的错误处理
- SQL注入防护: 使用参数化查询(虽然当前API不直接支持,需要手动转义)
- 性能优化: 对大表使用LIMIT限制返回行数
- Schema缓存: 考虑缓存schema信息以提高性能
与其他工具集成
1. 与Prometheus监控集成
// 监控查询执行时间
start := time.Now()
result, err := sqlTool.Query(ctx, query)
duration := time.Since(start)
// 记录指标...
2. 与日志系统集成
// 结构化日志记录
log.WithFields(log.Fields{
"query": query,
"duration": duration,
"error": err,
}).Info("SQL query executed")
这个工具包提供了强大而灵活的数据库操作能力,特别适合与AI/LLM应用集成,实现自然语言到SQL的转换。