From 511298e4524ac31c36d6a15f17a4789c39886a06 Mon Sep 17 00:00:00 2001
From: wxiaoguang <wxiaoguang@gmail.com>
Date: Fri, 23 Feb 2024 01:07:41 +0800
Subject: [PATCH] Use general token signing secret (#29205) (#29325)

Backport #29205 (including #29172)

Use a clearly defined "signing secret" for token signing.
---
 cmd/generate.go                              |  2 +-
 modules/base/tool.go                         |  2 +-
 modules/context/context.go                   |  3 +-
 modules/generate/generate.go                 | 24 +++++++-----
 modules/generate/generate_test.go            | 34 +++++++++++++++++
 modules/setting/lfs.go                       | 23 +++++-------
 modules/setting/oauth2.go                    | 39 +++++++++++++++-----
 modules/setting/oauth2_test.go               | 34 +++++++++++++++++
 modules/util/util.go                         | 11 ------
 modules/util/util_test.go                    | 14 -------
 routers/install/install.go                   |  2 +-
 services/auth/source/oauth2/jwtsigningkey.go |  8 +---
 services/packages/auth.go                    |  4 +-
 13 files changed, 130 insertions(+), 70 deletions(-)
 create mode 100644 modules/generate/generate_test.go
 create mode 100644 modules/setting/oauth2_test.go

diff --git a/cmd/generate.go b/cmd/generate.go
index 5922617217..34145af24d 100644
--- a/cmd/generate.go
+++ b/cmd/generate.go
@@ -70,7 +70,7 @@ func runGenerateInternalToken(c *cli.Context) error {
 }
 
 func runGenerateLfsJwtSecret(c *cli.Context) error {
-	_, jwtSecretBase64, err := generate.NewJwtSecretBase64()
+	_, jwtSecretBase64, err := generate.NewJwtSecretWithBase64()
 	if err != nil {
 		return err
 	}
diff --git a/modules/base/tool.go b/modules/base/tool.go
index 71dcb83fb4..971b55bf72 100644
--- a/modules/base/tool.go
+++ b/modules/base/tool.go
@@ -129,7 +129,7 @@ func CreateTimeLimitCode(data string, minutes int, startInf any) string {
 
 	// create sha1 encode string
 	sh := sha1.New()
-	_, _ = sh.Write([]byte(fmt.Sprintf("%s%s%s%s%d", data, setting.SecretKey, startStr, endStr, minutes)))
+	_, _ = sh.Write([]byte(fmt.Sprintf("%s%s%s%s%d", data, hex.EncodeToString(setting.GetGeneralTokenSigningSecret()), startStr, endStr, minutes)))
 	encoded := hex.EncodeToString(sh.Sum(nil))
 
 	code := fmt.Sprintf("%s%06d%s", startStr, minutes, encoded)
diff --git a/modules/context/context.go b/modules/context/context.go
index 8a94e958b5..58b12796e4 100644
--- a/modules/context/context.go
+++ b/modules/context/context.go
@@ -6,6 +6,7 @@ package context
 
 import (
 	"context"
+	"encoding/hex"
 	"html"
 	"html/template"
 	"io"
@@ -134,7 +135,7 @@ func NewWebContext(base *Base, render Render, session session.Store) *Context {
 func Contexter() func(next http.Handler) http.Handler {
 	rnd := templates.HTMLRenderer()
 	csrfOpts := CsrfOptions{
-		Secret:         setting.SecretKey,
+		Secret:         hex.EncodeToString(setting.GetGeneralTokenSigningSecret()),
 		Cookie:         setting.CSRFCookieName,
 		SetCookie:      true,
 		Secure:         setting.SessionConfig.Secure,
diff --git a/modules/generate/generate.go b/modules/generate/generate.go
index ee3c76059b..2d9a3dd902 100644
--- a/modules/generate/generate.go
+++ b/modules/generate/generate.go
@@ -7,6 +7,7 @@ package generate
 import (
 	"crypto/rand"
 	"encoding/base64"
+	"fmt"
 	"io"
 	"time"
 
@@ -38,19 +39,24 @@ func NewInternalToken() (string, error) {
 	return internalToken, nil
 }
 
-// NewJwtSecret generates a new value intended to be used for JWT secrets.
-func NewJwtSecret() ([]byte, error) {
-	bytes := make([]byte, 32)
-	_, err := io.ReadFull(rand.Reader, bytes)
-	if err != nil {
+const defaultJwtSecretLen = 32
+
+// DecodeJwtSecretBase64 decodes a base64 encoded jwt secret into bytes, and check its length
+func DecodeJwtSecretBase64(src string) ([]byte, error) {
+	encoding := base64.RawURLEncoding
+	decoded := make([]byte, encoding.DecodedLen(len(src))+3)
+	if n, err := encoding.Decode(decoded, []byte(src)); err != nil {
 		return nil, err
+	} else if n != defaultJwtSecretLen {
+		return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, defaultJwtSecretLen)
 	}
-	return bytes, nil
+	return decoded[:defaultJwtSecretLen], nil
 }
 
-// NewJwtSecretBase64 generates a new base64 encoded value intended to be used for JWT secrets.
-func NewJwtSecretBase64() ([]byte, string, error) {
-	bytes, err := NewJwtSecret()
+// NewJwtSecretWithBase64 generates a jwt secret with its base64 encoded value intended to be used for saving into config file
+func NewJwtSecretWithBase64() ([]byte, string, error) {
+	bytes := make([]byte, defaultJwtSecretLen)
+	_, err := io.ReadFull(rand.Reader, bytes)
 	if err != nil {
 		return nil, "", err
 	}
diff --git a/modules/generate/generate_test.go b/modules/generate/generate_test.go
new file mode 100644
index 0000000000..af640a60c1
--- /dev/null
+++ b/modules/generate/generate_test.go
@@ -0,0 +1,34 @@
+// Copyright 2024 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package generate
+
+import (
+	"encoding/base64"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestDecodeJwtSecretBase64(t *testing.T) {
+	_, err := DecodeJwtSecretBase64("abcd")
+	assert.ErrorContains(t, err, "invalid base64 decoded length")
+	_, err = DecodeJwtSecretBase64(strings.Repeat("a", 64))
+	assert.ErrorContains(t, err, "invalid base64 decoded length")
+
+	str32 := strings.Repeat("x", 32)
+	encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32))
+	decoded32, err := DecodeJwtSecretBase64(encoded32)
+	assert.NoError(t, err)
+	assert.Equal(t, str32, string(decoded32))
+}
+
+func TestNewJwtSecretWithBase64(t *testing.T) {
+	secret, encoded, err := NewJwtSecretWithBase64()
+	assert.NoError(t, err)
+	assert.Len(t, secret, 32)
+	decoded, err := DecodeJwtSecretBase64(encoded)
+	assert.NoError(t, err)
+	assert.Equal(t, secret, decoded)
+}
diff --git a/modules/setting/lfs.go b/modules/setting/lfs.go
index a5ea537cef..2034ef782c 100644
--- a/modules/setting/lfs.go
+++ b/modules/setting/lfs.go
@@ -4,22 +4,19 @@
 package setting
 
 import (
-	"encoding/base64"
 	"fmt"
 	"time"
 
 	"code.gitea.io/gitea/modules/generate"
-	"code.gitea.io/gitea/modules/util"
 )
 
 // LFS represents the configuration for Git LFS
 var LFS = struct {
-	StartServer     bool          `ini:"LFS_START_SERVER"`
-	JWTSecretBase64 string        `ini:"LFS_JWT_SECRET"`
-	JWTSecretBytes  []byte        `ini:"-"`
-	HTTPAuthExpiry  time.Duration `ini:"LFS_HTTP_AUTH_EXPIRY"`
-	MaxFileSize     int64         `ini:"LFS_MAX_FILE_SIZE"`
-	LocksPagingNum  int           `ini:"LFS_LOCKS_PAGING_NUM"`
+	StartServer    bool          `ini:"LFS_START_SERVER"`
+	JWTSecretBytes []byte        `ini:"-"`
+	HTTPAuthExpiry time.Duration `ini:"LFS_HTTP_AUTH_EXPIRY"`
+	MaxFileSize    int64         `ini:"LFS_MAX_FILE_SIZE"`
+	LocksPagingNum int           `ini:"LFS_LOCKS_PAGING_NUM"`
 
 	Storage *Storage
 }{}
@@ -61,10 +58,10 @@ func loadLFSFrom(rootCfg ConfigProvider) error {
 		return nil
 	}
 
-	LFS.JWTSecretBase64 = loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET")
-	LFS.JWTSecretBytes, err = util.Base64FixedDecode(base64.RawURLEncoding, []byte(LFS.JWTSecretBase64), 32)
+	jwtSecretBase64 := loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET")
+	LFS.JWTSecretBytes, err = generate.DecodeJwtSecretBase64(jwtSecretBase64)
 	if err != nil {
-		LFS.JWTSecretBytes, LFS.JWTSecretBase64, err = generate.NewJwtSecretBase64()
+		LFS.JWTSecretBytes, jwtSecretBase64, err = generate.NewJwtSecretWithBase64()
 		if err != nil {
 			return fmt.Errorf("error generating JWT Secret for custom config: %v", err)
 		}
@@ -74,8 +71,8 @@ func loadLFSFrom(rootCfg ConfigProvider) error {
 		if err != nil {
 			return fmt.Errorf("error saving JWT Secret for custom config: %v", err)
 		}
-		rootCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(LFS.JWTSecretBase64)
-		saveCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(LFS.JWTSecretBase64)
+		rootCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(jwtSecretBase64)
+		saveCfg.Section("server").Key("LFS_JWT_SECRET").SetValue(jwtSecretBase64)
 		if err := saveCfg.Save(); err != nil {
 			return fmt.Errorf("error saving JWT Secret for custom config: %v", err)
 		}
diff --git a/modules/setting/oauth2.go b/modules/setting/oauth2.go
index ab82393106..847e0f1ee8 100644
--- a/modules/setting/oauth2.go
+++ b/modules/setting/oauth2.go
@@ -4,13 +4,12 @@
 package setting
 
 import (
-	"encoding/base64"
 	"math"
 	"path/filepath"
+	"sync/atomic"
 
 	"code.gitea.io/gitea/modules/generate"
 	"code.gitea.io/gitea/modules/log"
-	"code.gitea.io/gitea/modules/util"
 )
 
 // OAuth2UsernameType is enum describing the way gitea 'name' should be generated from oauth2 data
@@ -98,7 +97,6 @@ var OAuth2 = struct {
 	RefreshTokenExpirationTime int64
 	InvalidateRefreshTokens    bool
 	JWTSigningAlgorithm        string `ini:"JWT_SIGNING_ALGORITHM"`
-	JWTSecretBase64            string `ini:"JWT_SECRET"`
 	JWTSigningPrivateKeyFile   string `ini:"JWT_SIGNING_PRIVATE_KEY_FILE"`
 	MaxTokenLength             int
 	DefaultApplications        []string
@@ -123,29 +121,50 @@ func loadOAuth2From(rootCfg ConfigProvider) {
 		return
 	}
 
-	OAuth2.JWTSecretBase64 = loadSecret(rootCfg.Section("oauth2"), "JWT_SECRET_URI", "JWT_SECRET")
+	jwtSecretBase64 := loadSecret(rootCfg.Section("oauth2"), "JWT_SECRET_URI", "JWT_SECRET")
 
 	if !filepath.IsAbs(OAuth2.JWTSigningPrivateKeyFile) {
 		OAuth2.JWTSigningPrivateKeyFile = filepath.Join(AppDataPath, OAuth2.JWTSigningPrivateKeyFile)
 	}
 
 	if InstallLock {
-		if _, err := util.Base64FixedDecode(base64.RawURLEncoding, []byte(OAuth2.JWTSecretBase64), 32); err != nil {
-			key, err := generate.NewJwtSecret()
+		jwtSecretBytes, err := generate.DecodeJwtSecretBase64(jwtSecretBase64)
+		if err != nil {
+			jwtSecretBytes, jwtSecretBase64, err = generate.NewJwtSecretWithBase64()
 			if err != nil {
 				log.Fatal("error generating JWT secret: %v", err)
 			}
-
-			OAuth2.JWTSecretBase64 = base64.RawURLEncoding.EncodeToString(key)
 			saveCfg, err := rootCfg.PrepareSaving()
 			if err != nil {
 				log.Fatal("save oauth2.JWT_SECRET failed: %v", err)
 			}
-			rootCfg.Section("oauth2").Key("JWT_SECRET").SetValue(OAuth2.JWTSecretBase64)
-			saveCfg.Section("oauth2").Key("JWT_SECRET").SetValue(OAuth2.JWTSecretBase64)
+			rootCfg.Section("oauth2").Key("JWT_SECRET").SetValue(jwtSecretBase64)
+			saveCfg.Section("oauth2").Key("JWT_SECRET").SetValue(jwtSecretBase64)
 			if err := saveCfg.Save(); err != nil {
 				log.Fatal("save oauth2.JWT_SECRET failed: %v", err)
 			}
 		}
+		generalSigningSecret.Store(&jwtSecretBytes)
 	}
 }
+
+// generalSigningSecret is used as container for a []byte value
+// instead of an additional mutex, we use CompareAndSwap func to change the value thread save
+var generalSigningSecret atomic.Pointer[[]byte]
+
+func GetGeneralTokenSigningSecret() []byte {
+	old := generalSigningSecret.Load()
+	if old == nil || len(*old) == 0 {
+		jwtSecret, _, err := generate.NewJwtSecretWithBase64()
+		if err != nil {
+			log.Fatal("Unable to generate general JWT secret: %s", err.Error())
+		}
+		if generalSigningSecret.CompareAndSwap(old, &jwtSecret) {
+			// FIXME: in main branch, the signing token should be refactored (eg: one unique for LFS/OAuth2/etc ...)
+			log.Warn("OAuth2 is not enabled, unable to use a persistent signing secret, a new one is generated, which is not persistent between restarts and cluster nodes")
+			return jwtSecret
+		}
+		return *generalSigningSecret.Load()
+	}
+	return *old
+}
diff --git a/modules/setting/oauth2_test.go b/modules/setting/oauth2_test.go
new file mode 100644
index 0000000000..d822198619
--- /dev/null
+++ b/modules/setting/oauth2_test.go
@@ -0,0 +1,34 @@
+// Copyright 2024 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package setting
+
+import (
+	"testing"
+
+	"code.gitea.io/gitea/modules/generate"
+	"code.gitea.io/gitea/modules/test"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetGeneralSigningSecret(t *testing.T) {
+	// when there is no general signing secret, it should be generated, and keep the same value
+	assert.Nil(t, generalSigningSecret.Load())
+	s1 := GetGeneralTokenSigningSecret()
+	assert.NotNil(t, s1)
+	s2 := GetGeneralTokenSigningSecret()
+	assert.Equal(t, s1, s2)
+
+	// the config value should always override any pre-generated value
+	cfg, _ := NewConfigProviderFromData(`
+[oauth2]
+JWT_SECRET = BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
+`)
+	defer test.MockVariableValue(&InstallLock, true)()
+	loadOAuth2From(cfg)
+	actual := GetGeneralTokenSigningSecret()
+	expected, _ := generate.DecodeJwtSecretBase64("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")
+	assert.Len(t, actual, 32)
+	assert.EqualValues(t, expected, actual)
+}
diff --git a/modules/util/util.go b/modules/util/util.go
index c47931f6c9..0e5c6a4e64 100644
--- a/modules/util/util.go
+++ b/modules/util/util.go
@@ -6,7 +6,6 @@ package util
 import (
 	"bytes"
 	"crypto/rand"
-	"encoding/base64"
 	"fmt"
 	"math/big"
 	"strconv"
@@ -246,13 +245,3 @@ func ToFloat64(number any) (float64, error) {
 func ToPointer[T any](val T) *T {
 	return &val
 }
-
-func Base64FixedDecode(encoding *base64.Encoding, src []byte, length int) ([]byte, error) {
-	decoded := make([]byte, encoding.DecodedLen(len(src))+3)
-	if n, err := encoding.Decode(decoded, src); err != nil {
-		return nil, err
-	} else if n != length {
-		return nil, fmt.Errorf("invalid base64 decoded length: %d, expects: %d", n, length)
-	}
-	return decoded[:length], nil
-}
diff --git a/modules/util/util_test.go b/modules/util/util_test.go
index 8509d8aced..c5830ce01c 100644
--- a/modules/util/util_test.go
+++ b/modules/util/util_test.go
@@ -4,7 +4,6 @@
 package util
 
 import (
-	"encoding/base64"
 	"regexp"
 	"strings"
 	"testing"
@@ -234,16 +233,3 @@ func TestToPointer(t *testing.T) {
 	val123 := 123
 	assert.False(t, &val123 == ToPointer(val123))
 }
-
-func TestBase64FixedDecode(t *testing.T) {
-	_, err := Base64FixedDecode(base64.RawURLEncoding, []byte("abcd"), 32)
-	assert.ErrorContains(t, err, "invalid base64 decoded length")
-	_, err = Base64FixedDecode(base64.RawURLEncoding, []byte(strings.Repeat("a", 64)), 32)
-	assert.ErrorContains(t, err, "invalid base64 decoded length")
-
-	str32 := strings.Repeat("x", 32)
-	encoded32 := base64.RawURLEncoding.EncodeToString([]byte(str32))
-	decoded32, err := Base64FixedDecode(base64.RawURLEncoding, []byte(encoded32), 32)
-	assert.NoError(t, err)
-	assert.Equal(t, str32, string(decoded32))
-}
diff --git a/routers/install/install.go b/routers/install/install.go
index 185e4bf6bf..e021fc3541 100644
--- a/routers/install/install.go
+++ b/routers/install/install.go
@@ -407,7 +407,7 @@ func SubmitInstall(ctx *context.Context) {
 		cfg.Section("server").Key("LFS_START_SERVER").SetValue("true")
 		cfg.Section("lfs").Key("PATH").SetValue(form.LFSRootPath)
 		var lfsJwtSecret string
-		if _, lfsJwtSecret, err = generate.NewJwtSecretBase64(); err != nil {
+		if _, lfsJwtSecret, err = generate.NewJwtSecretWithBase64(); err != nil {
 			ctx.RenderWithErr(ctx.Tr("install.lfs_jwt_secret_failed", err), tplInstall, &form)
 			return
 		}
diff --git a/services/auth/source/oauth2/jwtsigningkey.go b/services/auth/source/oauth2/jwtsigningkey.go
index eca0b8b7e1..070fffe60f 100644
--- a/services/auth/source/oauth2/jwtsigningkey.go
+++ b/services/auth/source/oauth2/jwtsigningkey.go
@@ -300,7 +300,7 @@ func InitSigningKey() error {
 	case "HS384":
 		fallthrough
 	case "HS512":
-		key, err = loadSymmetricKey()
+		key = setting.GetGeneralTokenSigningSecret()
 	case "RS256":
 		fallthrough
 	case "RS384":
@@ -333,12 +333,6 @@ func InitSigningKey() error {
 	return nil
 }
 
-// loadSymmetricKey checks if the configured secret is valid.
-// If it is not valid, it will return an error.
-func loadSymmetricKey() (any, error) {
-	return util.Base64FixedDecode(base64.RawURLEncoding, []byte(setting.OAuth2.JWTSecretBase64), 32)
-}
-
 // loadOrCreateAsymmetricKey checks if the configured private key exists.
 // If it does not exist a new random key gets generated and saved on the configured path.
 func loadOrCreateAsymmetricKey() (any, error) {
diff --git a/services/packages/auth.go b/services/packages/auth.go
index 2f78b26f50..8263c28bed 100644
--- a/services/packages/auth.go
+++ b/services/packages/auth.go
@@ -33,7 +33,7 @@ func CreateAuthorizationToken(u *user_model.User) (string, error) {
 	}
 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 
-	tokenString, err := token.SignedString([]byte(setting.SecretKey))
+	tokenString, err := token.SignedString(setting.GetGeneralTokenSigningSecret())
 	if err != nil {
 		return "", err
 	}
@@ -57,7 +57,7 @@ func ParseAuthorizationToken(req *http.Request) (int64, error) {
 		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
 			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
 		}
-		return []byte(setting.SecretKey), nil
+		return setting.GetGeneralTokenSigningSecret(), nil
 	})
 	if err != nil {
 		return 0, err