update migrate
This commit is contained in:
70
ent/ent.go
70
ent/ent.go
@ -3,13 +3,16 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"api/ent/card"
|
||||
"api/ent/group"
|
||||
"api/ent/ma"
|
||||
"api/ent/ue"
|
||||
"api/ent/user"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"t/ent/card"
|
||||
"t/ent/group"
|
||||
"t/ent/user"
|
||||
"sync"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
@ -62,35 +65,33 @@ func NewTxContext(parent context.Context, tx *Tx) context.Context {
|
||||
}
|
||||
|
||||
// OrderFunc applies an ordering on the sql selector.
|
||||
// Deprecated: Use Asc/Desc functions or the package builders instead.
|
||||
type OrderFunc func(*sql.Selector)
|
||||
|
||||
// columnChecker returns a function indicates if the column exists in the given column.
|
||||
func columnChecker(table string) func(string) error {
|
||||
checks := map[string]func(string) bool{
|
||||
card.Table: card.ValidColumn,
|
||||
group.Table: group.ValidColumn,
|
||||
user.Table: user.ValidColumn,
|
||||
}
|
||||
check, ok := checks[table]
|
||||
if !ok {
|
||||
return func(string) error {
|
||||
return fmt.Errorf("unknown table %q", table)
|
||||
}
|
||||
}
|
||||
return func(column string) error {
|
||||
if !check(column) {
|
||||
return fmt.Errorf("unknown column %q for table %q", column, table)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
initCheck sync.Once
|
||||
columnCheck sql.ColumnCheck
|
||||
)
|
||||
|
||||
// columnChecker checks if the column exists in the given table.
|
||||
func checkColumn(table, column string) error {
|
||||
initCheck.Do(func() {
|
||||
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
|
||||
card.Table: card.ValidColumn,
|
||||
group.Table: group.ValidColumn,
|
||||
ma.Table: ma.ValidColumn,
|
||||
ue.Table: ue.ValidColumn,
|
||||
user.Table: user.ValidColumn,
|
||||
})
|
||||
})
|
||||
return columnCheck(table, column)
|
||||
}
|
||||
|
||||
// Asc applies the given fields in ASC order.
|
||||
func Asc(fields ...string) OrderFunc {
|
||||
func Asc(fields ...string) func(*sql.Selector) {
|
||||
return func(s *sql.Selector) {
|
||||
check := columnChecker(s.TableName())
|
||||
for _, f := range fields {
|
||||
if err := check(f); err != nil {
|
||||
if err := checkColumn(s.TableName(), f); err != nil {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
|
||||
}
|
||||
s.OrderBy(sql.Asc(s.C(f)))
|
||||
@ -99,11 +100,10 @@ func Asc(fields ...string) OrderFunc {
|
||||
}
|
||||
|
||||
// Desc applies the given fields in DESC order.
|
||||
func Desc(fields ...string) OrderFunc {
|
||||
func Desc(fields ...string) func(*sql.Selector) {
|
||||
return func(s *sql.Selector) {
|
||||
check := columnChecker(s.TableName())
|
||||
for _, f := range fields {
|
||||
if err := check(f); err != nil {
|
||||
if err := checkColumn(s.TableName(), f); err != nil {
|
||||
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
|
||||
}
|
||||
s.OrderBy(sql.Desc(s.C(f)))
|
||||
@ -135,8 +135,7 @@ func Count() AggregateFunc {
|
||||
// Max applies the "max" aggregation function on the given field of each group.
|
||||
func Max(field string) AggregateFunc {
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
if err := checkColumn(s.TableName(), field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
|
||||
return ""
|
||||
}
|
||||
@ -147,8 +146,7 @@ func Max(field string) AggregateFunc {
|
||||
// Mean applies the "mean" aggregation function on the given field of each group.
|
||||
func Mean(field string) AggregateFunc {
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
if err := checkColumn(s.TableName(), field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
|
||||
return ""
|
||||
}
|
||||
@ -159,8 +157,7 @@ func Mean(field string) AggregateFunc {
|
||||
// Min applies the "min" aggregation function on the given field of each group.
|
||||
func Min(field string) AggregateFunc {
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
if err := checkColumn(s.TableName(), field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
|
||||
return ""
|
||||
}
|
||||
@ -171,8 +168,7 @@ func Min(field string) AggregateFunc {
|
||||
// Sum applies the "sum" aggregation function on the given field of each group.
|
||||
func Sum(field string) AggregateFunc {
|
||||
return func(s *sql.Selector) string {
|
||||
check := columnChecker(s.TableName())
|
||||
if err := check(field); err != nil {
|
||||
if err := checkColumn(s.TableName(), field); err != nil {
|
||||
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
|
||||
return ""
|
||||
}
|
||||
@ -509,7 +505,7 @@ func withHooks[V Value, M any, PM interface {
|
||||
return exec(ctx)
|
||||
}
|
||||
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
|
||||
mutationT, ok := m.(PM)
|
||||
mutationT, ok := any(m).(PM)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected mutation type %T", m)
|
||||
}
|
||||
|
Reference in New Issue
Block a user