Skip to content

Commit aa5ebc5

Browse files
authored
Bugfix/issue 34 col name conflict (#35)
1 parent 31be7b2 commit aa5ebc5

File tree

8 files changed

+306
-117
lines changed

8 files changed

+306
-117
lines changed

core/dbio/database/database.go

Lines changed: 84 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,36 +1481,70 @@ func (conn *BaseConn) GetTableColumns(table *Table, fields ...string) (columns i
14811481

14821482
// if fields provided, check if exists in table
14831483
colMap := map[string]map[string]any{}
1484-
for _, rec := range colData.Records() {
1485-
colName := cast.ToString(rec["column_name"])
1486-
colMap[strings.ToLower(colName)] = rec
1484+
caseSensitive := conn.GetType().DBNameCaseSensitive()
1485+
1486+
if caseSensitive {
1487+
for _, rec := range colData.Records() {
1488+
colName := cast.ToString(rec["column_name"])
1489+
colMap[colName] = rec
1490+
}
1491+
} else {
1492+
for _, rec := range colData.Records() {
1493+
colName := cast.ToString(rec["column_name"])
1494+
colMap[strings.ToLower(colName)] = rec
1495+
}
14871496
}
14881497

14891498
var colTypes []ColumnType
14901499

14911500
// if fields provided, filter, keep order
14921501
if len(fields) > 0 {
1493-
for _, field := range fields {
1494-
rec, ok := colMap[strings.ToLower(field)]
1495-
if !ok {
1496-
err = g.Error(
1497-
"provided field '%s' not found in table %s",
1498-
strings.ToLower(field), table.FullName(),
1499-
)
1500-
return
1501-
}
1502+
if caseSensitive {
1503+
for _, field := range fields {
1504+
rec, ok := colMap[(field)]
1505+
if !ok {
1506+
err = g.Error(
1507+
"provided field '%s' not found in table %s",
1508+
(field), table.FullName(),
1509+
)
1510+
return
1511+
}
15021512

1503-
if conn.Type == dbio.TypeDbSnowflake {
1504-
rec["data_type"], rec["precision"], rec["scale"] = parseSnowflakeDataType(rec)
1513+
if conn.Type == dbio.TypeDbSnowflake {
1514+
rec["data_type"], rec["precision"], rec["scale"] = parseSnowflakeDataType(rec)
1515+
}
1516+
1517+
colTypes = append(colTypes, ColumnType{
1518+
Name: cast.ToString(rec["column_name"]),
1519+
DatabaseTypeName: cast.ToString(rec["data_type"]),
1520+
Precision: cast.ToInt(rec["precision"]),
1521+
Scale: cast.ToInt(rec["scale"]),
1522+
Sourced: true,
1523+
})
15051524
}
1525+
} else {
1526+
for _, field := range fields {
1527+
rec, ok := colMap[strings.ToLower(field)]
1528+
if !ok {
1529+
err = g.Error(
1530+
"provided field '%s' not found in table %s",
1531+
strings.ToLower(field), table.FullName(),
1532+
)
1533+
return
1534+
}
15061535

1507-
colTypes = append(colTypes, ColumnType{
1508-
Name: cast.ToString(rec["column_name"]),
1509-
DatabaseTypeName: cast.ToString(rec["data_type"]),
1510-
Precision: cast.ToInt(rec["precision"]),
1511-
Scale: cast.ToInt(rec["scale"]),
1512-
Sourced: true,
1513-
})
1536+
if conn.Type == dbio.TypeDbSnowflake {
1537+
rec["data_type"], rec["precision"], rec["scale"] = parseSnowflakeDataType(rec)
1538+
}
1539+
1540+
colTypes = append(colTypes, ColumnType{
1541+
Name: cast.ToString(rec["column_name"]),
1542+
DatabaseTypeName: cast.ToString(rec["data_type"]),
1543+
Precision: cast.ToInt(rec["precision"]),
1544+
Scale: cast.ToInt(rec["scale"]),
1545+
Sourced: true,
1546+
})
1547+
}
15141548
}
15151549
} else {
15161550
colTypes = lo.Map(colData.Records(), func(rec map[string]interface{}, i int) ColumnType {
@@ -2052,31 +2086,38 @@ func (conn *BaseConn) CastColumnsForSelect(srcColumns iop.Columns, tgtColumns io
20522086
// ValidateColumnNames verifies that source fields are present in the target table
20532087
// It will return quoted field names as `newColNames`, the same length as `colNames`
20542088
func (conn *BaseConn) ValidateColumnNames(tgtCols iop.Columns, colNames []string, quote bool) (newCols iop.Columns, err error) {
2055-
2056-
tgtFields := map[string]string{}
2057-
for _, colName := range tgtCols.Names() {
2058-
colName = conn.Self().Unquote(colName)
2059-
if quote {
2060-
tgtFields[strings.ToLower(colName)] = conn.Self().Quote(colName)
2061-
} else {
2062-
tgtFields[strings.ToLower(colName)] = colName
2063-
}
2064-
}
2065-
20662089
mismatches := []string{}
2067-
for _, colName := range colNames {
2068-
newCol := tgtCols.GetColumn(colName)
2069-
if newCol == nil || newCol.Name == "" {
2070-
// src field is missing in tgt field
2071-
mismatches = append(mismatches, g.F("source field '%s' is missing in target table", colName))
2072-
continue
2090+
caseSensitive := conn.GetType().DBNameCaseSensitive()
2091+
if caseSensitive {
2092+
for _, colName := range colNames {
2093+
newCol := tgtCols.GetColumnWithOriginalCase(colName)
2094+
if newCol == nil || newCol.Name == "" {
2095+
// src field is missing in tgt field
2096+
mismatches = append(mismatches, g.F("source field '%s' is missing in target table", colName))
2097+
continue
2098+
}
2099+
if quote {
2100+
newCol.Name = conn.Self().Quote(newCol.Name)
2101+
} else {
2102+
newCol.Name = conn.Self().Unquote(newCol.Name)
2103+
}
2104+
newCols = append(newCols, *newCol)
20732105
}
2074-
if quote {
2075-
newCol.Name = conn.Self().Quote(newCol.Name)
2076-
} else {
2077-
newCol.Name = conn.Self().Unquote(newCol.Name)
2106+
} else {
2107+
for _, colName := range colNames {
2108+
newCol := tgtCols.GetColumn(colName)
2109+
if newCol == nil || newCol.Name == "" {
2110+
// src field is missing in tgt field
2111+
mismatches = append(mismatches, g.F("source field '%s' is missing in target table", colName))
2112+
continue
2113+
}
2114+
if quote {
2115+
newCol.Name = conn.Self().Quote(newCol.Name)
2116+
} else {
2117+
newCol.Name = conn.Self().Unquote(newCol.Name)
2118+
}
2119+
newCols = append(newCols, *newCol)
20782120
}
2079-
newCols = append(newCols, *newCol)
20802121
}
20812122

20822123
if len(mismatches) > 0 {

0 commit comments

Comments
 (0)