Support multiple LDAP servers in a auth source (#6898)

Signed-off-by: abhishek818 <abhishekguptaatweb17@gmail.com>
This commit is contained in:
abhishek818 2024-07-17 15:10:12 +05:30
parent b4ccef3dee
commit f954681e69
5 changed files with 49 additions and 29 deletions

View File

@ -207,7 +207,7 @@ func parseLdapConfig(c *cli.Context, config *ldap.Source) error {
config.Name = c.String("name") config.Name = c.String("name")
} }
if c.IsSet("host") { if c.IsSet("host") {
config.Host = c.String("host") config.HostList = c.String("hostlist")
} }
if c.IsSet("port") { if c.IsSet("port") {
config.Port = c.Int("port") config.Port = c.Int("port")

View File

@ -59,7 +59,7 @@ func TestAddLdapBindDn(t *testing.T) {
IsSyncEnabled: true, IsSyncEnabled: true,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Name: "ldap (via Bind DN) source full", Name: "ldap (via Bind DN) source full",
Host: "ldap-bind-server full", HostList: "ldap-bind-server full",
Port: 9876, Port: 9876,
SecurityProtocol: ldap.SecurityProtocol(1), SecurityProtocol: ldap.SecurityProtocol(1),
SkipVerify: true, SkipVerify: true,
@ -99,7 +99,7 @@ func TestAddLdapBindDn(t *testing.T) {
IsActive: true, IsActive: true,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Name: "ldap (via Bind DN) source min", Name: "ldap (via Bind DN) source min",
Host: "ldap-bind-server min", HostList: "ldap-bind-server min",
Port: 1234, Port: 1234,
SecurityProtocol: ldap.SecurityProtocol(0), SecurityProtocol: ldap.SecurityProtocol(0),
UserBase: "ou=Users,dc=min-domain-bind,dc=org", UserBase: "ou=Users,dc=min-domain-bind,dc=org",
@ -280,7 +280,7 @@ func TestAddLdapSimpleAuth(t *testing.T) {
IsActive: false, IsActive: false,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Name: "ldap (simple auth) source full", Name: "ldap (simple auth) source full",
Host: "ldap-simple-server full", HostList: "ldap-simple-server full",
Port: 987, Port: 987,
SecurityProtocol: ldap.SecurityProtocol(2), SecurityProtocol: ldap.SecurityProtocol(2),
SkipVerify: true, SkipVerify: true,
@ -317,7 +317,7 @@ func TestAddLdapSimpleAuth(t *testing.T) {
IsActive: true, IsActive: true,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Name: "ldap (simple auth) source min", Name: "ldap (simple auth) source min",
Host: "ldap-simple-server min", HostList: "ldap-simple-server min",
Port: 123, Port: 123,
SecurityProtocol: ldap.SecurityProtocol(0), SecurityProtocol: ldap.SecurityProtocol(0),
UserDN: "cn=%s,ou=Users,dc=min-domain-simple,dc=org", UserDN: "cn=%s,ou=Users,dc=min-domain-simple,dc=org",
@ -526,7 +526,7 @@ func TestUpdateLdapBindDn(t *testing.T) {
IsSyncEnabled: true, IsSyncEnabled: true,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Name: "ldap (via Bind DN) source full", Name: "ldap (via Bind DN) source full",
Host: "ldap-bind-server full", HostList: "ldap-bind-server full",
Port: 9876, Port: 9876,
SecurityProtocol: ldap.SecurityProtocol(1), SecurityProtocol: ldap.SecurityProtocol(1),
SkipVerify: true, SkipVerify: true,
@ -630,7 +630,7 @@ func TestUpdateLdapBindDn(t *testing.T) {
authSource: &auth.Source{ authSource: &auth.Source{
Type: auth.LDAP, Type: auth.LDAP,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Host: "ldap-server", HostList: "ldap-server",
}, },
}, },
}, },
@ -978,7 +978,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
IsActive: false, IsActive: false,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Name: "ldap (simple auth) source full", Name: "ldap (simple auth) source full",
Host: "ldap-simple-server full", HostList: "ldap-simple-server full",
Port: 987, Port: 987,
SecurityProtocol: ldap.SecurityProtocol(2), SecurityProtocol: ldap.SecurityProtocol(2),
SkipVerify: true, SkipVerify: true,
@ -1078,7 +1078,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
authSource: &auth.Source{ authSource: &auth.Source{
Type: auth.DLDAP, Type: auth.DLDAP,
Cfg: &ldap.Source{ Cfg: &ldap.Source{
Host: "ldap-server", HostList: "ldap-server",
}, },
}, },
}, },

View File

@ -121,7 +121,7 @@ func parseLDAPConfig(form forms.AuthenticationForm) *ldap.Source {
} }
return &ldap.Source{ return &ldap.Source{
Name: form.Name, Name: form.Name,
Host: form.Host, HostList: form.Host,
Port: form.Port, Port: form.Port,
SecurityProtocol: ldap.SecurityProtocol(form.SecurityProtocol), SecurityProtocol: ldap.SecurityProtocol(form.SecurityProtocol),
SkipVerify: form.SkipVerify, SkipVerify: form.SkipVerify,

View File

@ -25,7 +25,7 @@ import (
// Source Basic LDAP authentication service // Source Basic LDAP authentication service
type Source struct { type Source struct {
Name string // canonical name (ie. corporate.ad) Name string // canonical name (ie. corporate.ad)
Host string // LDAP host HostList string // list containing LDAP host(s)
Port int // port number Port int // port number
SecurityProtocol SecurityProtocol SecurityProtocol SecurityProtocol
SkipVerify bool SkipVerify bool

View File

@ -10,6 +10,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"time"
"code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/log"
@ -111,28 +112,47 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
func dial(source *Source) (*ldap.Conn, error) { func dial(source *Source) (*ldap.Conn, error) {
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)
tlsConfig := &tls.Config{ ldap.DefaultTimeout = time.Second * 15
ServerName: source.Host, // HostList is a list of hosts separated by commas
InsecureSkipVerify: source.SkipVerify, hostList := strings.Split(source.HostList, ",")
}
if source.SecurityProtocol == SecurityProtocolLDAPS { for _, host := range hostList {
return ldap.DialTLS("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port)), tlsConfig) tlsConfig := &tls.Config{
} ServerName: host,
InsecureSkipVerify: source.SkipVerify,
}
conn, err := ldap.Dial("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port))) if source.SecurityProtocol == SecurityProtocolLDAPS {
if err != nil { conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)
return nil, fmt.Errorf("error during Dial: %w", err)
}
if source.SecurityProtocol == SecurityProtocolStartTLS { if err != nil {
if err = conn.StartTLS(tlsConfig); err != nil { // Connection failed, try again with the next host.
conn.Close() log.Trace("error during Dial for host %s: %w", host, err)
return nil, fmt.Errorf("error during StartTLS: %w", err) continue
}
conn.SetTimeout(time.Second * 10)
return conn, err
}
conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
if err != nil {
log.Trace("error during Dial for host %s: %w", host, err)
continue
}
conn.SetTimeout(time.Second * 10)
if source.SecurityProtocol == SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsConfig); err != nil {
conn.Close()
log.Trace("error during StartTLS for host %s: %w", host, err)
continue
}
} }
} }
return conn, nil // All servers were unreachable
return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList)
} }
func bindUser(l *ldap.Conn, userDN, passwd string) error { func bindUser(l *ldap.Conn, userDN, passwd string) error {
@ -257,7 +277,7 @@ func (source *Source) SearchEntry(name, passwd string, directBind bool) *SearchR
} }
l, err := dial(source) l, err := dial(source)
if err != nil { if err != nil {
log.Error("LDAP Connect error, %s:%v", source.Host, err) log.Error("LDAP Connect error, %s:%v", source.HostList, err)
source.Enabled = false source.Enabled = false
return nil return nil
} }
@ -421,7 +441,7 @@ func (source *Source) UsePagedSearch() bool {
func (source *Source) SearchEntries() ([]*SearchResult, error) { func (source *Source) SearchEntries() ([]*SearchResult, error) {
l, err := dial(source) l, err := dial(source)
if err != nil { if err != nil {
log.Error("LDAP Connect error, %s:%v", source.Host, err) log.Error("LDAP Connect error, %s:%v", source.HostList, err)
source.Enabled = false source.Enabled = false
return nil, err return nil, err
} }