1. 前言
Gorm源码学习系列
此文是Gorm源码学习系列的第二篇,主要梳理下通过Gorm创建表的流程。
2. 创建行记录代码示例
gorm提供了以下几个接口来创建行记录
- 一次创建一行
func (db *DB) Create(value interface{}) (tx *DB)
- 批量创建
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
- 数据库不存在主键时创建,存在时更新
func (db *DB) Save(value interface{}) (tx *DB)
详细请看教程及源码finisher_api.go,这里使用func (db *DB) Create(value interface{}) (tx *DB)
来说明创建行记录等大致流程。
2.1 声明模型
type Stu struct { ID int64 `gorm:"column:id; primary_key" json:"id"` Age int64 `gorm:"column:age;"` Height int64 `gorm:"column:height;"` Weight int64 `gorm:"column:weight;"` } // 设置表名 func (Stu) TableName() string { return "t_student" }
模型代码的主要用途如下,
- 申明的表中有哪些列及每列的名称、特性等,如
gorm
标签指定每个字断对于的表的列名 - 通过实现
Tabler
接口指定了固定的表名,接口定义如下
type Tabler interface { TableName() string }
关于模型定义中更多的约定和约束等,请看教程。
出于分表等业务场景,我们并不希望固定模型等表名,gorm提供了func (db *DB) Table(name string, args ...interface{}) (tx *DB)
等方法
来动态指定表名,详情请看教程。
2.2 创建行
func main() { // 数据库连接, 具体查看https://www.cnblogs.com/amos01/p/16890747.html 连接数据库代码示例 db, _ := dbOpen() // 打开调试模式、会打印DML db = db.Debug() stu := &Stu{ Age: 18, Height: 185, Weight: 70, } db = db.Create(stu) fmt.Printf("Error:%v ID:%v RowsAffected:%vn", db.Error, stu.ID, db.RowsAffected) }
代码输出如下
$ go run main.go 2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33 [1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70) Error:<nil> ID:1027 RowsAffected:1
从代码输出可以看,行记录的ID为1027,连接数据库查询,结果如下。
mysql> select * from t_student where id = 1027G *************************** 1. row *************************** id: 1027 age: 18 height: 185 weight: 70 1 row in set (0.01 sec)
因此,我们带着以下问题来梳理下Gorm创建行记录的流程
- 如何从model到DML语句的
- 如何将ID写入到model的
3. 从Model到DML
func (db *DB) Create(value interface{}) (tx *DB)
的实现如下
// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) } tx = db.getInstance() tx.Statement.Dest = value return tx.callbacks.Create().Execute(tx) }
func (p *processor) Execute(db *DB) *DB
的实现比较长,具体代码见github
总结下来,做了两件主要的事情,
- 解析model获取表名、每列的定义等
- 执行钩子函数以及创建行函数
X
3.1 数据结构理解
gorm.Statement
查看gorm.Statement代码
// Statement statement type Statement struct { *DB TableExpr *clause.Expr Table string // 表名 Model interface{} // model定义 Unscoped bool Dest interface{} // model的另外一种表达,如map ReflectValue reflect.Value Clauses map[string]clause.Clause BuildClauses []string Distinct bool Selects []string // selected columns Omits []string // omit columns Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool // 数据库连接 Schema *schema.Schema // 表结构化信息 Context context.Context RaiseErrorOnNotFound bool SkipHooks bool SQL strings.Builder // 最终的DML语句 Vars []interface{} // DML语句的参数值 CurDestIndex int // 批量创建/更新时,gorm当前操作的数组/slice的下标 attrs []interface{} assigns []interface{} scopes []func(*DB) *DB }
schema.Schem
查看schema.Schema代码
type Schema struct { Name string ModelType reflect.Type Table string // 表名 PrioritizedPrimaryField *Field DBNames []string // 表每列的名字 PrimaryFields []*Field PrimaryFieldDBNames []string // 表的主键列明 Fields []*Field // gorm自定义的model每个字短 FieldsByName map[string]*Field FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships CreateClauses []clause.Interface // 创建行的子句 QueryClauses []clause.Interface UpdateClauses []clause.Interface DeleteClauses []clause.Interface BeforeCreate, AfterCreate bool BeforeUpdate, AfterUpdate bool BeforeDelete, AfterDelete bool BeforeSave, AfterSave bool AfterFind bool err error initialized chan struct{} namer Namer cacheStore *sync.Map }
schema.Field
查看schema.Field代码
// Field is the representation of model schema's field type Field struct { Name string // model的字段名 DBName string // 对应表的列名 BindNames []string DataType DataType GORMDataType DataType PrimaryKey bool AutoIncrement bool AutoIncrementIncrement int64 Creatable bool Updatable bool Readable bool AutoCreateTime TimeType AutoUpdateTime TimeType HasDefaultValue bool DefaultValue string DefaultValueInterface interface{} NotNull bool Unique bool Comment string Size int Precision int Scale int IgnoreMigration bool FieldType reflect.Type // 反射类型 IndirectFieldType reflect.Type // 反射类型 StructField reflect.StructField // model字段信息 Tag reflect.StructTag // tag TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema ReflectValueOf func(context.Context, reflect.Value) reflect.Value // 通过反射获取该字段的反射对象 ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) // 通过反射获取该字段的值 get方法 Set func(context.Context, reflect.Value, interface{}) error // 通过反射设置该字段的值 set方法 Serializer SerializerInterface NewValuePool FieldNewValuePool }
clause.Interface
及clause.Clause
gorm定义了多种clause,包括
查看clause.Interface代码
// Interface clause interface type Interface interface { Name() string Build(Builder) MergeClause(*Clause) }
查看clause.Clause代码
// Clause type Clause struct { Name string // WHERE BeforeExpression Expression AfterNameExpression Expression AfterExpression Expression Expression Expression Builder ClauseBuilder }
3.2 解析Model
通过调用stmt.Parse(stmt.Model)
进行model解析
stmt.Parse(stmt.Model)
会调用到函数func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error)
进行解析。
详细代码见schema.go,下面列举重要的几个点。
- 通过反射判断
dest interface{}
是否为reflect.Struct
- 通过接口获取表名,其中stu实现了
Tabler
接口
// 获取表名 modelValue := reflect.New(modelType) tableName := namer.TableName(modelType.Name()) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } if specialTableName != "" && specialTableName != tableName { tableName = specialTableName }
- 解析model每个字段
// 通过反射获取每个字段 for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { // 解析每个字段 if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) } else { schema.Fields = append(schema.Fields, field) } } }
- 放到map方便查找,并且通过
func (field *Field) setupValuerAndSetter()
初始化每个Field的ReflectValueOf
、ValueOf
、Set
方法。for _, field := range schema.Fields { if field.DBName == "" && field.DataType != "" { field.DBName = namer.ColumnName(schema.Table, field.Name) } if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if _, ok := schema.FieldsByDBName[field.DBName]; !ok { schema.DBNames = append(schema.DBNames, field.DBName) } // gorm tag字段到field的映射 schema.FieldsByDBName[field.DBName] = field // model 字段到field的映射 schema.FieldsByName[field.Name] = field if v != nil && v.PrimaryKey { for idx, f := range schema.PrimaryFields { if f == v { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } } // 主键 if field.PrimaryKey { schema.PrimaryFields = append(schema.PrimaryFields, field) } } } if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } // 挂载字段的set方法和get方法 field.setupValuerAndSetter() }
值得一提的是,每个model解析后的结果是一致,可以将结果解析的结构缓存下来,并且通过chan
来解决并发的问题。
解析model之后,通过process
获取到钩子函数及创建行的函数,具体代码见Github
for _, f := range p.fns { f(db) }
3.3 执行钩子函数及创建行的函数
创建行的函数及对应的钩子函数位于create.go
- 创建行记录
if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build(db.Statement.BuildClauses...) }
这里插入两个clause.Clause
,分别为clause.Insert
以及clause.Values
,然后调用这两种clause.Clause
的build
方法生成SQL
语句。
首先,看下ConvertToCreateValues
的实现,这里只截取部分代码
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} // 获取每一列的名字 for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { values.Columns = append(values.Columns, clause.Column{Name: db}) } } } // 获取每一列对应的值 switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: case reflect.Struct: values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] // func (field *Field) setupValuerAndSetter() 挂载的方法 if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } }
通过ConvertToCreateValues
获取了每一列的名称及对应的值。
接下来,看clause.Clause
到SQL
语句的过程。
遍历加入clause
,此时分别为clause.Insert
以及clause.Values
// Build build sql with clauses names func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool for _, name := range clauses { if c, ok := stmt.Clauses[name]; ok { // 代码有删减 c.Build(stmt) } } }
接着调用func (c Clause) Build(builder Builder)
// Build build clause func (c Clause) Build(builder Builder) { // 有删减 // c为clause.Insert以及clause.Values if c.Name != "" { // builder写入 INSERT 或者 VALUES builder.WriteString(c.Name) builder.WriteByte(' ') } // 通过clause.Insert以及clause.Values的MergeClause函数,c.Expression为clause.Insert以及clause.Values // 因此,这里调用clause.Insert或者clause.Values的Build的方法 c.Expression.Build(builder) }
接下来分别看clause.Insert
以及clause.Values
// Build build insert clause func (insert Insert) Build(builder Builder) { // builder写入INTO,此时builder为INSERT INTO builder.WriteString("INTO ") // builder写入表名 builder.WriteQuoted(currentTable) }
从调用的链路可以得出,这里builder
为stmt *Statement
,并且currentTable
类型为clause.Table
,因此
// WriteQuoted write quoted value func (stmt *Statement) WriteQuoted(value interface{}) { stmt.QuoteTo(&stmt.SQL, value) } // QuoteTo write quoted value to writer 有删减 func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write := func(raw bool, str string) { // mysql驱动Dialector stmt.DB.Dialector.QuoteTo(writer, str) } switch v := field.(type) { case clause.Table: write(v.Raw, stmt.Table) } }
至此,builder
已经拼装出INSERT INTO `t_student`
,解析来再看clause.Values
的build
方法
// Build build from clause func (values Values) Build(builder Builder) { if len(values.Columns) > 0 { builder.WriteByte('(') for idx, column := range values.Columns { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(column) } builder.WriteByte(')') builder.WriteString(" VALUES ") for idx, value := range values.Values { if idx > 0 { builder.WriteByte(',') } builder.WriteByte('(') builder.AddVar(builder, value...) builder.WriteByte(')') } } else { builder.WriteString("DEFAULT VALUES") } }
func (values Values) Build(builder Builder)
取出所有列名和列对应的值
最终builder
拼装成例子的完整SQL语句INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)
有了SQL语句,就可以执行了
result, err := db.Statement.ConnPool.ExecContext( db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., )
通过前一面学习,db.Statement.ConnPool
的值为sql.DB
,实际执行的函数为func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error)
至此,从Model到DML到流程已经完成。
4. 将ID写入到model的
看返回参数sql.Result
,因此通过LastInsertId() (int64, error)
可以获取到插入行的ID值。
// A Result summarizes an executed SQL command. type Result interface { // LastInsertId returns the integer generated by the database // in response to a command. Typically this will be from an // "auto increment" column when inserting a new row. Not all // databases support this feature, and the syntax of such // statements varies. LastInsertId() (int64, error) // RowsAffected returns the number of rows affected by an // update, insert, or delete. Not every database or database // driver may support this. RowsAffected() (int64, error) }
获取到刚插入的行ID值,再通过反射写入model的ID字段即可。
db.RowsAffected, _ = result.RowsAffected() if db.RowsAffected != 0 && db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { insertID, err := result.LastInsertId() switch db.Statement.ReflectValue.Kind() { case reflect.Struct: _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { // 通过反射更新ID db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } }
5. 总结
使用反射解析Model,获得每个成员对应的表的列名、值等信息。
定义SQL各个关键词如INSERT
、VALUES
、FROM
、DELETE
的结构体,并实现clause.Interface
接口
进而对SQL语句的构造进行抽象封装。