From f2c4cae867536f6f2e2f8f3519691f933df2e388 Mon Sep 17 00:00:00 2001 From: abhishek818 Date: Wed, 24 Jul 2024 15:19:55 +0530 Subject: [PATCH] race the tcp connections (#6898) Signed-off-by: abhishek818 --- services/auth/source/ldap/source_search.go | 76 +++++++++++++--------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go index ad5ebe365c..5ad2247e33 100644 --- a/services/auth/source/ldap/source_search.go +++ b/services/auth/source/ldap/source_search.go @@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { func dial(source *Source) (*ldap.Conn, error) { log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) - ldap.DefaultTimeout = time.Second * 15 + ldap.DefaultTimeout = time.Second * 10 // Remove any extra spaces in HostList string tempHostList := strings.ReplaceAll(source.HostList, " ", "") // HostList is a list of hosts separated by commas hostList := strings.Split(tempHostList, ",") - // hostList := strings.Split(source.HostList, ",") + + type result struct { + conn *ldap.Conn + err error + } + + results := make(chan result, len(hostList)) for _, host := range hostList { - tlsConfig := &tls.Config{ - ServerName: host, - InsecureSkipVerify: source.SkipVerify, - } + go func(host string) { + tlsConfig := &tls.Config{ + ServerName: host, + InsecureSkipVerify: source.SkipVerify, + } + + var conn *ldap.Conn + var err error + + if source.SecurityProtocol == SecurityProtocolLDAPS { + conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) + } else { + conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) + if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS { + err = conn.StartTLS(tlsConfig) + } + } - if source.SecurityProtocol == SecurityProtocolLDAPS { - conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig) if err != nil { - // Connection failed, try again with the next host. - conn.Close() + if conn != nil { + conn.Close() + } log.Trace("error during Dial for host %s: %w", host, err) - continue + results <- result{nil, err} + return } + conn.SetTimeout(time.Second * 10) + results <- result{conn, nil} + }(host) + } - return conn, err - } - - conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port))) - if err != nil { - conn.Close() - log.Trace("error during Dial for host %s: %w", host, err) - continue - } - conn.SetTimeout(time.Second * 10) - - if source.SecurityProtocol == SecurityProtocolStartTLS { - if err = conn.StartTLS(tlsConfig); err != nil { - conn.Close() - log.Trace("error during StartTLS for host %s: %w", host, err) - continue - } + for range hostList { + r := <-results + if r.err == nil { + // Close other connections still in progress + go func() { + for range hostList { + r := <-results + if r.conn != nil { + r.conn.Close() + } + } + }() + return r.conn, nil } } - // All servers were unreachable return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList) }