From 144cfe5720d28da4a3b9a0697234740d88d3b4c3 Mon Sep 17 00:00:00 2001
From: zeripath <art27@cantab.net>
Date: Fri, 5 Mar 2021 13:19:17 +0000
Subject: [PATCH] Fix race in local storage (#14888)

LocalStorage should only put completed files in position

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 modules/storage/local.go | 50 +++++++++++++++++++++++++++++++---------
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/modules/storage/local.go b/modules/storage/local.go
index 93af13ee33..982d2b88c6 100644
--- a/modules/storage/local.go
+++ b/modules/storage/local.go
@@ -7,6 +7,7 @@ package storage
 import (
 	"context"
 	"io"
+	"io/ioutil"
 	"net/url"
 	"os"
 	"path/filepath"
@@ -24,13 +25,15 @@ const LocalStorageType Type = "local"
 
 // LocalStorageConfig represents the configuration for a local storage
 type LocalStorageConfig struct {
-	Path string `ini:"PATH"`
+	Path          string `ini:"PATH"`
+	TemporaryPath string `ini:"TEMPORARY_PATH"`
 }
 
 // LocalStorage represents a local files storage
 type LocalStorage struct {
-	ctx context.Context
-	dir string
+	ctx    context.Context
+	dir    string
+	tmpdir string
 }
 
 // NewLocalStorage returns a local files
@@ -46,9 +49,14 @@ func NewLocalStorage(ctx context.Context, cfg interface{}) (ObjectStorage, error
 		return nil, err
 	}
 
+	if config.TemporaryPath == "" {
+		config.TemporaryPath = config.Path + "/tmp"
+	}
+
 	return &LocalStorage{
-		ctx: ctx,
-		dir: config.Path,
+		ctx:    ctx,
+		dir:    config.Path,
+		tmpdir: config.TemporaryPath,
 	}, nil
 }
 
@@ -64,17 +72,37 @@ func (l *LocalStorage) Save(path string, r io.Reader) (int64, error) {
 		return 0, err
 	}
 
-	// always override
-	if err := util.Remove(p); err != nil {
+	// Create a temporary file to save to
+	if err := os.MkdirAll(l.tmpdir, os.ModePerm); err != nil {
 		return 0, err
 	}
-
-	f, err := os.Create(p)
+	tmp, err := ioutil.TempFile(l.tmpdir, "upload-*")
 	if err != nil {
 		return 0, err
 	}
-	defer f.Close()
-	return io.Copy(f, r)
+	tmpRemoved := false
+	defer func() {
+		if !tmpRemoved {
+			_ = util.Remove(tmp.Name())
+		}
+	}()
+
+	n, err := io.Copy(tmp, r)
+	if err != nil {
+		return 0, err
+	}
+
+	if err := tmp.Close(); err != nil {
+		return 0, err
+	}
+
+	if err := os.Rename(tmp.Name(), p); err != nil {
+		return 0, err
+	}
+
+	tmpRemoved = true
+
+	return n, nil
 }
 
 // Stat returns the info of the file