mirror of
https://github.com/go-gitea/gitea.git
synced 2024-09-01 14:56:30 +00:00
race the tcp connections (#6898)
Signed-off-by: abhishek818 <abhishekguptaatweb17@gmail.com>
This commit is contained in:
parent
88ca8fabea
commit
f2c4cae867
@ -112,50 +112,68 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
|
||||
func dial(source *Source) (*ldap.Conn, error) {
|
||||
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)
|
||||
|
||||
ldap.DefaultTimeout = time.Second * 15
|
||||
ldap.DefaultTimeout = time.Second * 10
|
||||
// Remove any extra spaces in HostList string
|
||||
tempHostList := strings.ReplaceAll(source.HostList, " ", "")
|
||||
// HostList is a list of hosts separated by commas
|
||||
hostList := strings.Split(tempHostList, ",")
|
||||
// hostList := strings.Split(source.HostList, ",")
|
||||
|
||||
type result struct {
|
||||
conn *ldap.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
results := make(chan result, len(hostList))
|
||||
|
||||
for _, host := range hostList {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: host,
|
||||
InsecureSkipVerify: source.SkipVerify,
|
||||
}
|
||||
go func(host string) {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: host,
|
||||
InsecureSkipVerify: source.SkipVerify,
|
||||
}
|
||||
|
||||
var conn *ldap.Conn
|
||||
var err error
|
||||
|
||||
if source.SecurityProtocol == SecurityProtocolLDAPS {
|
||||
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
|
||||
} else {
|
||||
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
|
||||
if err == nil && source.SecurityProtocol == SecurityProtocolStartTLS {
|
||||
err = conn.StartTLS(tlsConfig)
|
||||
}
|
||||
}
|
||||
|
||||
if source.SecurityProtocol == SecurityProtocolLDAPS {
|
||||
conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
|
||||
if err != nil {
|
||||
// Connection failed, try again with the next host.
|
||||
conn.Close()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
log.Trace("error during Dial for host %s: %w", host, err)
|
||||
continue
|
||||
results <- result{nil, err}
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetTimeout(time.Second * 10)
|
||||
results <- result{conn, nil}
|
||||
}(host)
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
log.Trace("error during Dial for host %s: %w", host, err)
|
||||
continue
|
||||
}
|
||||
conn.SetTimeout(time.Second * 10)
|
||||
|
||||
if source.SecurityProtocol == SecurityProtocolStartTLS {
|
||||
if err = conn.StartTLS(tlsConfig); err != nil {
|
||||
conn.Close()
|
||||
log.Trace("error during StartTLS for host %s: %w", host, err)
|
||||
continue
|
||||
}
|
||||
for range hostList {
|
||||
r := <-results
|
||||
if r.err == nil {
|
||||
// Close other connections still in progress
|
||||
go func() {
|
||||
for range hostList {
|
||||
r := <-results
|
||||
if r.conn != nil {
|
||||
r.conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
return r.conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// All servers were unreachable
|
||||
return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user