race the tcp connections (#6898)

Signed-off-by: abhishek818 <abhishekguptaatweb17@gmail.com>
This commit is contained in:
abhishek818 2024-07-24 15:19:55 +05:30
parent 88ca8fabea
commit f2c4cae867

View File

@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
func dial(source *Source) (*ldap.Conn, error) {
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)
ldap.DefaultTimeout = time.Second * 15
ldap.DefaultTimeout = time.Second * 10
// Remove any extra spaces in HostList string
tempHostList := strings.ReplaceAll(source.HostList, " ", "")
// HostList is a list of hosts separated by commas
hostList := strings.Split(tempHostList, ",")
// hostList := strings.Split(source.HostList, ",")
type result struct {
conn *ldap.Conn
err error
}
results := make(chan result, len(hostList))
for _, host := range hostList {
tlsConfig := &tls.Config{
ServerName: host,
InsecureSkipVerify: source.SkipVerify,
}
go func(host string) {
tlsConfig := &tls.Config{
ServerName: host,
InsecureSkipVerify: source.SkipVerify,
}
var conn *ldap.Conn
var err error
if source.SecurityProtocol == SecurityProtocolLDAPS {
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
} else {
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS {
err = conn.StartTLS(tlsConfig)
}
}
if source.SecurityProtocol == SecurityProtocolLDAPS {
conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
if err != nil {
// Connection failed, try again with the next host.
conn.Close()
if conn != nil {
conn.Close()
}
log.Trace("error during Dial for host %s: %w", host, err)
continue
results <- result{nil, err}
return
}
conn.SetTimeout(time.Second * 10)
results <- result{conn, nil}
}(host)
}
return conn, err
}
conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
if err != nil {
conn.Close()
log.Trace("error during Dial for host %s: %w", host, err)
continue
}
conn.SetTimeout(time.Second * 10)
if source.SecurityProtocol == SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsConfig); err != nil {
conn.Close()
log.Trace("error during StartTLS for host %s: %w", host, err)
continue
}
for range hostList {
r := <-results
if r.err == nil {
// Close other connections still in progress
go func() {
for range hostList {
r := <-results
if r.conn != nil {
r.conn.Close()
}
}
}()
return r.conn, nil
}
}
// All servers were unreachable
return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList)
}