feat: improve code import and commnet (#18)

* feat: improve code import and commnet

* feat: improve code

* fix: modify err

* fix: modify table  to
This commit is contained in:
houseme 2023-02-15 17:13:54 +08:00 committed by GitHub
parent 225b7eed74
commit e60a42c400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 269 additions and 157 deletions

View File

@ -11,29 +11,32 @@ package controller
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"ohurlshortener/core"
"ohurlshortener/service"
"ohurlshortener/utils"
"ohurlshortener/utils/export"
"strconv"
"strings"
"time"
"github.com/dchest/captcha"
"github.com/gin-gonic/gin"
)
const (
DEFAULT_PAGE_NUM = 1
DEFAULT_PAGE_SIZE = 20
DefaultPageNum = 1
DefaultPageSize = 20
)
// LoginPage 登录页面
func LoginPage(c *gin.Context) {
c.HTML(http.StatusOK, "login.html", gin.H{
"title": "登录 - ohUrlShortener",
})
}
// DoLogin 登录
func DoLogin(c *gin.Context) {
account := c.PostForm("account")
password := c.PostForm("password")
@ -56,7 +59,7 @@ func DoLogin(c *gin.Context) {
return
}
//验证码有效性验证
// 验证码有效性验证
if !captcha.VerifyString(captchaId, captchaText) {
c.HTML(http.StatusOK, "login.html", gin.H{
"title": "错误 - ohUrlShortener",
@ -65,7 +68,7 @@ func DoLogin(c *gin.Context) {
return
}
//用户名密码有效性验证
// 用户名密码有效性验证
loginUser, err := service.Login(account, password)
if err != nil || loginUser.IsEmpty() {
c.HTML(http.StatusOK, "login.html", gin.H{
@ -75,7 +78,7 @@ func DoLogin(c *gin.Context) {
return
}
//Write Cookie to browser
// Write Cookie to browser
cValue, err := AdminCookieValue(loginUser)
if err != nil {
c.HTML(http.StatusOK, "login.html", gin.H{
@ -89,21 +92,25 @@ func DoLogin(c *gin.Context) {
c.Redirect(http.StatusFound, "/admin/dashboard")
}
// DoLogout 登出
func DoLogout(c *gin.Context) {
c.SetCookie("ohUrlShortenerAdmin", "", -1, "/", "", false, true)
c.SetCookie("ohUrlShortenerCookie", "", -1, "/", "", false, true)
c.Redirect(http.StatusFound, "/login")
}
// ServeCaptchaImage 生成验证码
func ServeCaptchaImage(c *gin.Context) {
captcha.Server(200, 45).ServeHTTP(c.Writer, c.Request)
}
// RequestCaptchaImage 请求验证码
func RequestCaptchaImage(c *gin.Context) {
imageId := captcha.New()
c.JSON(http.StatusOK, core.ResultJsonSuccessWithData(imageId))
}
// ChangeState 修改状态
func ChangeState(c *gin.Context) {
destUrl := c.PostForm("dest_url")
enable := c.PostForm("enable")
@ -128,6 +135,7 @@ func ChangeState(c *gin.Context) {
c.JSON(http.StatusOK, core.ResultJsonSuccessWithData(result))
}
// DeleteShortUrl 删除短链接
func DeleteShortUrl(c *gin.Context) {
url := c.PostForm("short_url")
if utils.EmptyString(strings.TrimSpace(url)) {
@ -144,6 +152,7 @@ func DeleteShortUrl(c *gin.Context) {
c.JSON(http.StatusOK, core.ResultJsonSuccess())
}
// GenerateShortUrl 生成短链接
func GenerateShortUrl(c *gin.Context) {
destUrl := c.PostForm("dest_url")
memo := c.PostForm("memo")
@ -165,17 +174,18 @@ func GenerateShortUrl(c *gin.Context) {
c.JSON(http.StatusOK, core.ResultJsonSuccessWithData(json))
}
// StatsPage 统计页面
func StatsPage(c *gin.Context) {
url := c.DefaultQuery("url", "")
strPage := c.DefaultQuery("page", strconv.Itoa(DEFAULT_PAGE_NUM))
strSize := c.DefaultQuery("size", strconv.Itoa(DEFAULT_PAGE_SIZE))
strPage := c.DefaultQuery("page", strconv.Itoa(DefaultPageNum))
strSize := c.DefaultQuery("size", strconv.Itoa(DefaultPageSize))
page, err := strconv.Atoi(strPage)
if err != nil {
page = DEFAULT_PAGE_NUM
page = DefaultPageNum
}
size, err := strconv.Atoi(strSize)
if err != nil {
size = DEFAULT_PAGE_SIZE
size = DefaultPageSize
}
urls, err := service.GetPagedUrlIpCountStats(strings.TrimSpace(url), page, size)
c.HTML(http.StatusOK, "stats.html", gin.H{
@ -192,17 +202,18 @@ func StatsPage(c *gin.Context) {
})
}
// SearchStatsPage 查询统计页面
func SearchStatsPage(c *gin.Context) {
url := c.DefaultQuery("url", "")
strPage := c.DefaultQuery("page", strconv.Itoa(DEFAULT_PAGE_NUM))
strSize := c.DefaultQuery("size", strconv.Itoa(DEFAULT_PAGE_SIZE))
strPage := c.DefaultQuery("page", strconv.Itoa(DefaultPageNum))
strSize := c.DefaultQuery("size", strconv.Itoa(DefaultPageSize))
page, err := strconv.Atoi(strPage)
if err != nil {
page = DEFAULT_PAGE_NUM
page = DefaultPageNum
}
size, err := strconv.Atoi(strSize)
if err != nil {
size = DEFAULT_PAGE_SIZE
size = DefaultPageSize
}
urls, err := service.GetPagedUrlIpCountStats(strings.TrimSpace(url), page, size)
c.HTML(http.StatusOK, "search_stats.html", gin.H{
@ -219,17 +230,18 @@ func SearchStatsPage(c *gin.Context) {
})
}
// UrlsPage 短链接列表页面
func UrlsPage(c *gin.Context) {
url := c.DefaultQuery("url", "")
strPage := c.DefaultQuery("page", strconv.Itoa(DEFAULT_PAGE_NUM))
strSize := c.DefaultQuery("size", strconv.Itoa(DEFAULT_PAGE_SIZE))
strPage := c.DefaultQuery("page", strconv.Itoa(DefaultPageNum))
strSize := c.DefaultQuery("size", strconv.Itoa(DefaultPageSize))
page, err := strconv.Atoi(strPage)
if err != nil {
page = DEFAULT_PAGE_NUM
page = DefaultPageNum
}
size, err := strconv.Atoi(strSize)
if err != nil {
size = DEFAULT_PAGE_SIZE
size = DefaultPageSize
}
urls, err := service.GetPagesShortUrls(strings.TrimSpace(url), page, size)
c.HTML(http.StatusOK, "urls.html", gin.H{
@ -246,19 +258,20 @@ func UrlsPage(c *gin.Context) {
})
}
// AccessLogsPage 访问日志页面
func AccessLogsPage(c *gin.Context) {
url := c.DefaultQuery("url", "")
strPage := c.DefaultQuery("page", strconv.Itoa(DEFAULT_PAGE_NUM))
strSize := c.DefaultQuery("size", strconv.Itoa(DEFAULT_PAGE_SIZE))
strPage := c.DefaultQuery("page", strconv.Itoa(DefaultPageNum))
strSize := c.DefaultQuery("size", strconv.Itoa(DefaultPageSize))
start := c.DefaultQuery("start", "")
end := c.DefaultQuery("end", "")
page, err := strconv.Atoi(strPage)
if err != nil {
page = DEFAULT_PAGE_NUM
page = DefaultPageNum
}
size, err := strconv.Atoi(strSize)
if err != nil {
size = DEFAULT_PAGE_SIZE
size = DefaultPageSize
}
totalCount, distinctIpCount, err := service.GetAccessLogsCount(strings.TrimSpace(url), start, end)
@ -281,6 +294,7 @@ func AccessLogsPage(c *gin.Context) {
})
}
// AccessLogsExport 导出访问日志
func AccessLogsExport(c *gin.Context) {
url := c.PostForm("url")
logs, err := service.GetAllAccessLogs(strings.TrimSpace(url))
@ -314,6 +328,7 @@ func AccessLogsExport(c *gin.Context) {
c.Data(http.StatusOK, "pplication/octet-stream", fileContent)
}
// DashboardPage 仪表盘页面
func DashboardPage(c *gin.Context) {
count, stats, err := service.GetSumOfUrlStats()
if err != nil {

View File

@ -11,18 +11,19 @@ package controller
import (
"fmt"
"net/http"
"strconv"
"strings"
"ohurlshortener/core"
"ohurlshortener/service"
"ohurlshortener/utils"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// APINewAdmin
//
//Add new admin user
// Add new admin user
func APINewAdmin(ctx *gin.Context) {
account := ctx.PostForm("account")
password := ctx.PostForm("password")
@ -48,7 +49,7 @@ func APINewAdmin(ctx *gin.Context) {
// APIAdminUpdate
//
//Update password of given admin user
// Update password of given admin user
func APIAdminUpdate(ctx *gin.Context) {
account := ctx.Param("account")
password := ctx.PostForm("password")
@ -103,7 +104,7 @@ func APIUrlInfo(ctx *gin.Context) {
}
stat, err := service.GetShortUrlStats(strings.TrimSpace(url))
if utils.EmptyString(strings.TrimSpace(url)) {
if err != nil {
ctx.JSON(http.StatusInternalServerError, core.ResultJsonError(err.Error()))
return
}
@ -135,6 +136,7 @@ func APIUpdateUrl(ctx *gin.Context) {
ctx.JSON(http.StatusOK, core.ResultJsonSuccessWithData(res))
}
// APIDeleteUrl Delete Short Url
func APIDeleteUrl(ctx *gin.Context) {
url := ctx.Param("url")
if utils.EmptyString(strings.TrimSpace(url)) {

View File

@ -12,25 +12,26 @@ import (
"fmt"
"log"
"net/http"
"strconv"
"strings"
"ohurlshortener/core"
"ohurlshortener/service"
"ohurlshortener/storage"
"ohurlshortener/utils"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
const (
authoriationHeaderKey = "Authorization"
authoriationTypeBearer = "Bearer"
authorizationHeaderKey = "Authorization"
authorizationTypeBearer = "Bearer"
)
// APIAuthHandler Authorization for /api
func APIAuthHandler() gin.HandlerFunc {
return func(ctx *gin.Context) {
authHeader := ctx.GetHeader(authoriationHeaderKey)
authHeader := ctx.GetHeader(authorizationHeaderKey)
if utils.EmptyString(authHeader) {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, core.ResultJsonUnauthorized("Authorization Header is empty"))
return
@ -42,7 +43,7 @@ func APIAuthHandler() gin.HandlerFunc {
return
}
if fields[0] != authoriationTypeBearer {
if fields[0] != authorizationTypeBearer {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, core.ResultJsonUnauthorized("Unsupported Authorization Type"))
return
}
@ -63,6 +64,7 @@ func APIAuthHandler() gin.HandlerFunc {
}
}
// AdminCookieValue Generate cookie value for admin user
func AdminCookieValue(user core.User) (string, error) {
var result string
data, err := utils.Sha256Of(user.Account + "a=" + user.Password + "=e" + strconv.Itoa(user.ID))
@ -73,19 +75,20 @@ func AdminCookieValue(user core.User) (string, error) {
return utils.Base58Encode(data), nil
}
// AdminAuthHandler Authorization for /admin
func AdminAuthHandler() gin.HandlerFunc {
return func(c *gin.Context) {
user, err := c.Cookie("ohUrlShortenerAdmin")
if err != nil {
c.AbortWithStatus(http.StatusUnauthorized)
//c.AbortWithError(http.StatusFound, err)
// c.AbortWithError(http.StatusFound, err)
return
}
cookie, err := c.Cookie("ohUrlShortenerCookie")
if err != nil {
c.AbortWithStatus(http.StatusUnauthorized)
//c.Redirect(http.StatusFound, "/login")
// c.Redirect(http.StatusFound, "/login")
return
}
@ -122,9 +125,10 @@ func AdminAuthHandler() gin.HandlerFunc {
}
c.Next()
} //end of func
} // end of func
}
// WebLogFormatHandler Customized log format for web
func WebLogFormatHandler(server string) gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
if !strings.HasPrefix(param.Path, "/assets") {
@ -139,10 +143,10 @@ func WebLogFormatHandler(server string) gin.HandlerFunc {
param.Request.UserAgent(),
param.ErrorMessage,
)
} //end of if
} // end of if
return ""
}) //end of formatter
} //end of func
}) // end of formatter
} // end of func
func validateToken(token string) (bool, error) {
users, err := storage.FindAllUsers()

View File

@ -10,12 +10,14 @@ package controller
import (
"net/http"
"ohurlshortener/service"
"ohurlshortener/utils"
"github.com/gin-gonic/gin"
)
// ShortUrlDetail 重定向到目标地址
func ShortUrlDetail(c *gin.Context) {
url := c.Param("url")
if utils.EmptyString(url) {
@ -50,6 +52,5 @@ func ShortUrlDetail(c *gin.Context) {
}
go service.NewAccessLog(url, c.ClientIP(), c.Request.UserAgent(), c.Request.Referer())
c.Redirect(http.StatusFound, destUrl)
}

View File

@ -13,6 +13,7 @@ import (
"time"
)
// AccessLog 访问日志
type AccessLog struct {
ID int64 `db:"id"`
ShortUrl string `db:"short_url"`

View File

@ -13,6 +13,7 @@ import (
"time"
)
// ResultJson 返回结果
type ResultJson struct {
Code int `json:"code"`
Status bool `json:"status"`
@ -21,6 +22,7 @@ type ResultJson struct {
Date time.Time `json:"date"`
}
// ResultJsonSuccess 返回成功结果
func ResultJsonSuccess() ResultJson {
return ResultJson{
Code: http.StatusOK,
@ -31,6 +33,7 @@ func ResultJsonSuccess() ResultJson {
}
}
// ResultJsonSuccessWithData 返回成功结果
func ResultJsonSuccessWithData(data interface{}) ResultJson {
return ResultJson{
Code: http.StatusOK,
@ -41,6 +44,7 @@ func ResultJsonSuccessWithData(data interface{}) ResultJson {
}
}
// ResultJsonError 返回错误结果
func ResultJsonError(message string) ResultJson {
return ResultJson{
Code: http.StatusInternalServerError,
@ -51,6 +55,7 @@ func ResultJsonError(message string) ResultJson {
}
}
// ResultJsonBadRequest 返回错误结果
func ResultJsonBadRequest(message string) ResultJson {
return ResultJson{
Code: http.StatusBadRequest,
@ -61,6 +66,7 @@ func ResultJsonBadRequest(message string) ResultJson {
}
}
// ResultJsonUnauthorized 返回错误结果
func ResultJsonUnauthorized(message string) ResultJson {
return ResultJson{
Code: http.StatusUnauthorized,

View File

@ -11,11 +11,13 @@ package core
import (
"database/sql"
"fmt"
"ohurlshortener/utils"
"reflect"
"time"
"ohurlshortener/utils"
)
// ShortUrl 短链接
type ShortUrl struct {
ID int64 `db:"id" json:"id"`
ShortUrl string `db:"short_url" json:"short_url"`
@ -25,10 +27,12 @@ type ShortUrl struct {
Memo sql.NullString `db:"memo" json:"memo"`
}
// IsEmpty 判断是否为空
func (url ShortUrl) IsEmpty() bool {
return reflect.DeepEqual(url, ShortUrl{})
}
// GenerateShortLink 生成短链接
func GenerateShortLink(initialLink string) (string, error) {
if utils.EmptyString(initialLink) {
return "", fmt.Errorf("empty string")

View File

@ -8,6 +8,7 @@
package core
// ShortUrlStats 短链接统计
type ShortUrlStats struct {
ShortUrl string `db:"short_url" json:"short_url"`
TodayCount int `db:"today_count" json:"today_count"`
@ -22,16 +23,19 @@ type ShortUrlStats struct {
DistinctTotalCount int `db:"d_total_count" json:"d_total_count"`
}
// Top25Url 短链接统计
type Top25Url struct {
ShortUrl
ShortUrlStats
}
// UrlIpCountStats 短链接统计
type UrlIpCountStats struct {
ShortUrl
ShortUrlStats
}
// StatsSum 短链接统计
type StatsSum struct {
Key string `db:"stats_key"`
Value int `db:"stats_value"`

View File

@ -2,12 +2,14 @@ package core
import "reflect"
// User 用户
type User struct {
ID int `db:"id"`
Account string `db:"account"`
Password string `db:"password"`
}
// IsEmpty 判断是否为空
func (user User) IsEmpty() bool {
return reflect.DeepEqual(user, User{})
}

3
go.sum
View File

@ -10,14 +10,12 @@ github.com/btcsuite/btcd v0.22.0-beta.0.20220111032746-97732e52810c/go.mod h1:tj
github.com/btcsuite/btcd v0.23.0 h1:V2/ZgjfDFIygAX3ZapeigkVBoVUtOJKSwrhZdlpSvaA=
github.com/btcsuite/btcd v0.23.0/go.mod h1:0QJIIN1wwIXF/3G/m87gIwGniDMDQqjVn4SZgnFpsYY=
github.com/btcsuite/btcd/btcec/v2 v2.1.0/go.mod h1:2VzYrv4Gm4apmbVVsSq5bqf1Ec8v56E48Vt0Y/umPgA=
github.com/btcsuite/btcd/btcec/v2 v2.1.3 h1:xM/n3yIhHAhHy04z4i43C8p4ehixJZMsnrVJkgl+MTE=
github.com/btcsuite/btcd/btcec/v2 v2.1.3/go.mod h1:ctjw4H1kknNJmRN4iP1R7bTQ+v3GJkZBd6mui8ZsAZE=
github.com/btcsuite/btcd/btcutil v1.0.0/go.mod h1:Uoxwv0pqYWhD//tfTiipkxNfdhG9UrLwaeswfjfdF0A=
github.com/btcsuite/btcd/btcutil v1.1.0/go.mod h1:5OapHB7A2hBBWLm48mmw4MOHNJCcUBTwmWH/0Jn8VHE=
github.com/btcsuite/btcd/btcutil v1.1.3 h1:xfbtw8lwpp0G6NwSHb+UE67ryTFHJAiNuipusjXSohQ=
github.com/btcsuite/btcd/btcutil v1.1.3/go.mod h1:UR7dsSJzJUfMmFiiLlIrMq1lS9jh9EdCV7FStZSnpi0=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg=
@ -43,7 +41,6 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/dchest/captcha v1.0.0 h1:vw+bm/qMFvTgcjQlYVTuQBJkarm5R0YSsDKhm1HZI2o=
github.com/dchest/captcha v1.0.0/go.mod h1:7zoElIawLp7GUMLcj54K9kbw+jEyvz2K0FDdRRYhvWo=
github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs=
github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=

56
main.go
View File

@ -16,13 +16,14 @@ import (
"io/fs"
"log"
"net/http"
"os"
"strings"
"time"
"ohurlshortener/controller"
"ohurlshortener/service"
"ohurlshortener/storage"
"ohurlshortener/utils"
"os"
"strings"
"time"
"github.com/Masterminds/sprig"
"github.com/gin-gonic/gin"
@ -30,20 +31,20 @@ import (
)
const (
WEB_READ_TIMEOUT = 15 * time.Second
WEB_WRITE_TIMEOUT = 15 * time.Second
WebReadTimeout = 15 * time.Second
WebWriteTimeout = 15 * time.Second
//清理 Redis 中的访问日志的时间间隔
ACCESS_LOG_CLEAN_INTERVAL = 1 * time.Minute
// AccessLogCleanInterval 清理 Redis 中的访问日志的时间间隔
AccessLogCleanInterval = 1 * time.Minute
// Top25 榜单计算间隔
TOP25_CALC_INTERVAL = 5 * time.Minute
// Top25CalcInterval Top25 榜单计算间隔
Top25CalcInterval = 5 * time.Minute
// 仪表盘页面中其他几个统计数据计算间隔
STATS_SUM_CALC_INTERVAL = 5 * time.Minute
// StatsSumCalcInterval 仪表盘页面中其他几个统计数据计算间隔
StatsSumCalcInterval = 5 * time.Minute
//全部访问日志分析统计的间隔
STATS_IP_SUM_CALC_INTERVAL = 30 * time.Minute
// StatsIpSumCalcInterval 全部访问日志分析统计的间隔
StatsIpSumCalcInterval = 30 * time.Minute
)
var (
@ -57,11 +58,10 @@ var (
)
func main() {
flag.StringVar(&cmdStart, "s", "", "starts ohUrlShortener service: admin | portal ")
flag.StringVar(&cmdConfig, "c", "config.ini", "config file path")
flag.Usage = func() {
fmt.Fprintf(os.Stdout, `ohUrlShortener version:%s
fmt.Fprintf(os.Stdout, `ohUrlShortener version:%s
Usage: ohurlshortener [-s admin|portal|<omit to start both>] [-c config_file_path]`, utils.Version)
flag.PrintDefaults()
}
@ -79,15 +79,15 @@ func main() {
portal := &http.Server{
Addr: fmt.Sprintf(":%d", utils.AppConfig.Port),
Handler: router01,
ReadTimeout: WEB_READ_TIMEOUT,
WriteTimeout: WEB_WRITE_TIMEOUT,
ReadTimeout: WebReadTimeout,
WriteTimeout: WebWriteTimeout,
}
admin := &http.Server{
Addr: fmt.Sprintf(":%d", utils.AppConfig.AdminPort),
Handler: router02,
ReadTimeout: WEB_READ_TIMEOUT,
WriteTimeout: WEB_WRITE_TIMEOUT,
ReadTimeout: WebReadTimeout,
WriteTimeout: WebWriteTimeout,
}
if strings.EqualFold("admin", strings.TrimSpace(cmdStart)) {
@ -106,7 +106,7 @@ func main() {
}
func initSettings() {
//Things MUST BE DONE before app starts
// Things MUST BE DONE before app starts
_, err := utils.InitConfig(cmdConfig)
utils.ExitOnError("Config initialization failed.", err)
@ -114,8 +114,8 @@ func initSettings() {
utils.ExitOnError("Redis initialization failed.", err)
_, err = storage.InitDatabaseService()
storage.CallProcedureStatsTop25() //recalculate when ohUrlShortener starts
storage.CallProcedureStatsSum() //recalculate when ohUrlShortener starts
storage.CallProcedureStatsTop25() // recalculate when ohUrlShortener starts
storage.CallProcedureStatsSum() // recalculate when ohUrlShortener starts
utils.ExitOnError("Database initialization failed.", err)
_, err = service.ReloadUrls()
@ -193,7 +193,7 @@ func initializeRoute01() (http.Handler, error) {
})
})
return router, nil
} //end of router01
} // end of router01
func initializeRoute02() (http.Handler, error) {
@ -258,10 +258,10 @@ func initializeRoute02() (http.Handler, error) {
})
})
return router, nil
} //end of router01
} // end of router01
func startTicker1() error {
redisTicker := time.NewTicker(ACCESS_LOG_CLEAN_INTERVAL)
redisTicker := time.NewTicker(AccessLogCleanInterval)
for range redisTicker.C {
log.Println("[StoreAccessLog] Start.")
if err := service.StoreAccessLogs(); err != nil {
@ -273,7 +273,7 @@ func startTicker1() error {
}
func startTicker2() error {
top25Ticker := time.NewTicker(TOP25_CALC_INTERVAL)
top25Ticker := time.NewTicker(Top25CalcInterval)
for range top25Ticker.C {
log.Println("[Top25Urls Ticker] Start.")
if err := storage.CallProcedureStatsTop25(); err != nil {
@ -285,7 +285,7 @@ func startTicker2() error {
}
func startTicker3() error {
statsIpSumTicker := time.NewTicker(STATS_IP_SUM_CALC_INTERVAL)
statsIpSumTicker := time.NewTicker(StatsIpSumCalcInterval)
for range statsIpSumTicker.C {
log.Println("[StatsIpSum Ticker] Start.")
if err := storage.CallProcedureStatsIPSum(); err != nil {
@ -297,7 +297,7 @@ func startTicker3() error {
}
func startTicker4() error {
statsSumTicker := time.NewTicker(STATS_SUM_CALC_INTERVAL)
statsSumTicker := time.NewTicker(StatsSumCalcInterval)
for range statsSumTicker.C {
log.Println("[StatsSum Ticker] Start.")
if err := storage.CallProcedureStatsSum(); err != nil {

View File

@ -13,26 +13,29 @@ import (
"encoding/json"
"fmt"
"log"
"time"
"ohurlshortener/core"
"ohurlshortener/storage"
"ohurlshortener/utils"
"time"
)
const access_logs_prefix = "OH_ACCESS_LOGS#"
const accessLogsPrefix = "OH_ACCESS_LOGS#"
// NewAccessLog 记录访问日志
func NewAccessLog(url string, ip string, useragent string, referer string) error {
var (
l = core.AccessLog{
ShortUrl: url,
AccessTime: time.Now(),
Ip: sql.NullString{String: ip, Valid: true},
UserAgent: sql.NullString{String: useragent, Valid: true},
}
logJson, _ = json.Marshal(l)
key = fmt.Sprintf("%s%s", accessLogsPrefix, utils.UserAgentIpHash(useragent, ip))
err = storage.RedisSet30m(key, logJson)
)
l := core.AccessLog{
ShortUrl: url,
AccessTime: time.Now(),
Ip: sql.NullString{String: ip, Valid: true},
UserAgent: sql.NullString{String: useragent, Valid: true},
}
logJson, _ := json.Marshal(l)
key := fmt.Sprintf("%s%s", access_logs_prefix, utils.UserAgentIpHash(useragent, ip))
err := storage.RedisSet30m(key, logJson)
if err != nil {
log.Println(err)
return utils.RaiseError("内部错误,请联系管理员")
@ -41,24 +44,25 @@ func NewAccessLog(url string, ip string, useragent string, referer string) error
return nil
}
// StoreAccessLogs 将访问日志存入数据库
func StoreAccessLogs() error {
keys, err := storage.RedisScan4Keys(access_logs_prefix + "*")
keys, err := storage.RedisScan4Keys(accessLogsPrefix + "*")
if err != nil {
log.Println(err)
return utils.RaiseError("内部错误,请联系管理员")
}
logs := []core.AccessLog{}
var logs []core.AccessLog
for _, k := range keys {
v, err := storage.RedisGetString(k)
if err != nil {
log.Printf("redis error for key %s", k)
continue
}
log := core.AccessLog{}
json.Unmarshal([]byte(v), &log)
logs = append(logs, log)
} //end of for
accessLog := core.AccessLog{}
json.Unmarshal([]byte(v), &accessLog)
logs = append(logs, accessLog)
} // end of for
err = storage.InsertAccessLogs(logs)
if err != nil {
@ -74,6 +78,7 @@ func StoreAccessLogs() error {
return nil
}
// GetPagedAccessLogs 获取分页访问日志
func GetPagedAccessLogs(url string, start, end string, page, size int) ([]core.AccessLog, error) {
if page < 1 || size < 1 {
return nil, nil
@ -86,10 +91,12 @@ func GetPagedAccessLogs(url string, start, end string, page, size int) ([]core.A
return allAccessLogs, nil
}
// GetAccessLogsCount 获取访问日志总数
func GetAccessLogsCount(url string, start, end string) (int, int, error) {
return storage.FindAccessLogsCount(url, start, end)
}
// GetAllAccessLogs 获取所有访问日志
func GetAllAccessLogs(url string) ([]core.AccessLog, error) {
allAccessLogs, err := storage.FindAllAccessLogsByUrl(url)
if err != nil {

View File

@ -9,11 +9,12 @@
package service
import (
"ohurlshortener/storage"
"ohurlshortener/utils"
"testing"
"github.com/bxcodec/faker/v3"
"ohurlshortener/storage"
"ohurlshortener/utils"
)
func TestStoreAccessLog(t *testing.T) {

View File

@ -14,6 +14,7 @@ import (
"ohurlshortener/utils"
)
// GetSumOfUrlStats 获取所有短链接的统计信息
func GetSumOfUrlStats() (int, core.ShortUrlStats, error) {
var (
totalCount int
@ -33,6 +34,7 @@ func GetSumOfUrlStats() (int, core.ShortUrlStats, error) {
return totalCount, result, nil
}
// GetShortUrlStats 获取单个短链接的统计信息
func GetShortUrlStats(url string) (core.ShortUrlStats, error) {
found, err := storage.GetUrlStats(url)
if err != nil {
@ -41,6 +43,7 @@ func GetShortUrlStats(url string) (core.ShortUrlStats, error) {
return found, nil
}
// GetTop25Url 获取访问量最高的 25 个短链接
func GetTop25Url() ([]core.Top25Url, error) {
found, err := storage.GetTop25()
if err != nil {
@ -49,6 +52,7 @@ func GetTop25Url() ([]core.Top25Url, error) {
return found, nil
}
// GetPagedUrlIpCountStats 获取单个短链接的 IP 访问量统计信息
func GetPagedUrlIpCountStats(url string, page int, size int) ([]core.UrlIpCountStats, error) {
if page < 1 || size < 1 {
return nil, nil

View File

@ -12,10 +12,11 @@ import (
"database/sql"
"fmt"
"log"
"time"
"ohurlshortener/core"
"ohurlshortener/storage"
"ohurlshortener/utils"
"time"
)
// ReloadUrls
@ -23,28 +24,28 @@ import (
// 从数据库中获取所有「有效」状态的短链接
// 并将其可以 key-> value 形式存入 Redis 中
func ReloadUrls() (bool, error) {
//把所有访问日志记录到数据库中
// 把所有访问日志记录到数据库中
err := StoreAccessLogs()
if err != nil {
log.Println(err)
return false, utils.RaiseError("内部错误,请联系管理员")
}
//找出所有已经配置好的短链接
// 找出所有已经配置好的短链接
urls, err := storage.FindAllShortUrls()
if err != nil {
log.Println(err)
return false, utils.RaiseError("内部错误,请联系管理员")
}
//清理 redis db
// 清理 redis db
err = storage.RedisFlushDB()
if err != nil {
log.Println(err)
return false, utils.RaiseError("内部错误,请联系管理员")
}
//将所有「有效」状态的短域名再次放入 Redis
// 将所有「有效」状态的短域名再次放入 Redis
for _, url := range urls {
if url.Valid {
err := storage.RedisSet4Ever(url.ShortUrl, url.DestUrl)
@ -53,16 +54,15 @@ func ReloadUrls() (bool, error) {
continue
}
}
} //end of for
} // end of for
return true, nil
}
// Search4ShortUrl
//
//从 Redis 中查询目标短链接是否存在
func Search4ShortUrl(shortUrl string) (string, error) {
destUrl, err := storage.RedisGetString(shortUrl)
if err != nil {
// 从 Redis 中查询目标短链接是否存在
func Search4ShortUrl(shortUrl string) (destUrl string, err error) {
if destUrl, err = storage.RedisGetString(shortUrl); err != nil {
log.Println(err)
return "", utils.RaiseError("内部错误,请联系管理员")
}
@ -71,7 +71,7 @@ func Search4ShortUrl(shortUrl string) (string, error) {
// GetPagesShortUrls
//
//获取分页的短链接信息
// 获取分页的短链接信息
func GetPagesShortUrls(url string, page int, size int) ([]core.ShortUrl, error) {
if page < 1 || size < 1 {
return nil, nil
@ -86,7 +86,7 @@ func GetPagesShortUrls(url string, page int, size int) ([]core.ShortUrl, error)
// GenerateShortUrl
//
//生成短链接
// 生成短链接
func GenerateShortUrl(destUrl string, memo string) (string, error) {
shortUrl, err := core.GenerateShortLink(destUrl)
if err != nil {
@ -133,7 +133,7 @@ func GenerateShortUrl(destUrl string, memo string) (string, error) {
// ChangeState
//
//禁用/启用短链接
// 禁用/启用短链接
func ChangeState(shortUrl string, enable bool) (bool, error) {
found, err := storage.FindShortUrl(shortUrl)
if err != nil {
@ -160,7 +160,7 @@ func ChangeState(shortUrl string, enable bool) (bool, error) {
return true, nil
}
//DeleteUrlAndAccessLogs 删除短链接以及对应的访问日志
// DeleteUrlAndAccessLogs 删除短链接以及对应的访问日志
func DeleteUrlAndAccessLogs(shortUrl string) error {
found, err := storage.FindShortUrl(shortUrl)
if err != nil {

View File

@ -3,17 +3,18 @@ package service
import (
"encoding/json"
"fmt"
"strings"
"ohurlshortener/core"
"ohurlshortener/storage"
"ohurlshortener/utils"
"strings"
)
const ADMIN_USER_PREFIX = "ohUrlShortenerAdmin#"
const ADMIN_COOKIE_PREFIX = "ohUrlShortenerCookie#"
const AdminUserPrefix = "ohUrlShortenerAdmin#"
const AdminCookiePrefix = "ohUrlShortenerCookie#"
// Login 登录
func Login(account string, pasword string) (core.User, error) {
var found core.User
found, err := GetUserByAccountFromRedis(account)
if err != nil {
@ -36,6 +37,7 @@ func Login(account string, pasword string) (core.User, error) {
return found, nil
}
// ReloadUsers 从数据库中获取所有用户
func ReloadUsers() error {
users, err := storage.FindAllUsers()
if err != nil {
@ -44,7 +46,7 @@ func ReloadUsers() error {
for _, user := range users {
jsonUser, _ := json.Marshal(user)
er := storage.RedisSet4Ever(ADMIN_USER_PREFIX+user.Account, jsonUser)
er := storage.RedisSet4Ever(AdminUserPrefix+user.Account, jsonUser)
if er != nil {
return er
}
@ -55,7 +57,7 @@ func ReloadUsers() error {
func GetUserByAccountFromRedis(account string) (core.User, error) {
var found core.User
foundUserStr, err := storage.RedisGetString(ADMIN_USER_PREFIX + account)
foundUserStr, err := storage.RedisGetString(AdminUserPrefix + account)
if err != nil {
return found, err
}

View File

@ -10,6 +10,7 @@ package storage
import (
"fmt"
"ohurlshortener/core"
"ohurlshortener/utils"
)
@ -20,9 +21,12 @@ func DeleteAccessLogs(shortUrl string) error {
}
func FindAccessLogs(shortUrl string) ([]core.AccessLog, error) {
found := []core.AccessLog{}
query := "SELECT * FROM public.access_logs l WHERE l.short_url = $1 ORDER BY l.id DESC"
err := DbSelect(query, &found, shortUrl)
var (
found []core.AccessLog
query = "SELECT * FROM public.access_logs l WHERE l.short_url = $1 ORDER BY l.id DESC"
err = DbSelect(query, &found, shortUrl)
)
return found, err
}
@ -31,8 +35,8 @@ func InsertAccessLogs(logs []core.AccessLog) error {
return nil
}
query := `INSERT INTO public.access_logs (short_url, access_time, ip, user_agent) VALUES(:short_url,:access_time,:ip,:user_agent)`
if len(logs) >= Max_Insert_Count {
logsSlice := splitLogsArray(logs, Max_Insert_Count)
if len(logs) >= MaxInsertCount {
logsSlice := splitLogsArray(logs, MaxInsertCount)
for _, slice := range logsSlice {
err := DbNamedExec(query, slice)
if err != nil {
@ -45,7 +49,7 @@ func InsertAccessLogs(logs []core.AccessLog) error {
// FindAccessLogsCount
//
// Find Access Logs Count and Unique IP Count
// # Find Access Logs Count and Unique IP Count
//
// First return value is total_count, Second return value is unique_ip_count ip count
func FindAccessLogsCount(url string, start, end string) (int, int, error) {
@ -68,9 +72,12 @@ func FindAccessLogsCount(url string, start, end string) (int, int, error) {
}
func FindAllAccessLogs(url string, start, end string, page, size int) ([]core.AccessLog, error) {
found := []core.AccessLog{}
offset := (page - 1) * size
query := `SELECT * FROM public.access_logs l WHERE 1=1 `
var (
found []core.AccessLog
offset = (page - 1) * size
query = `SELECT * FROM public.access_logs l WHERE 1=1 `
)
if !utils.EmptyString(url) {
query += fmt.Sprintf(` AND l.short_url = '%s'`, url)
}
@ -86,10 +93,13 @@ func FindAllAccessLogs(url string, start, end string, page, size int) ([]core.Ac
}
func FindAllAccessLogsByUrl(url string) ([]core.AccessLog, error) {
found := []core.AccessLog{}
query := "SELECT * FROM public.access_logs l ORDER BY l.id DESC"
var (
found []core.AccessLog
query = "SELECT * FROM public.access_logs l ORDER BY l.id DESC"
)
if !utils.EmptyString(url) {
query := "SELECT * FROM public.access_logs l WHERE l.short_url = $1 ORDER BY l.id DESC"
query = "SELECT * FROM public.access_logs l WHERE l.short_url = $1 ORDER BY l.id DESC"
err := DbSelect(query, &found, url)
return found, err
}

View File

@ -3,10 +3,11 @@ package storage
import (
"database/sql"
"math/rand"
"ohurlshortener/core"
"testing"
"time"
"ohurlshortener/core"
"github.com/bxcodec/faker/v3"
)

View File

@ -11,18 +11,21 @@ package storage
import (
"database/sql"
"fmt"
"ohurlshortener/utils"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"ohurlshortener/utils"
)
var dbService = &DatabaseService{}
// DatabaseService 数据库服务
type DatabaseService struct {
Connection *sqlx.DB
}
// InitDatabaseService 初始化数据库服务
func InitDatabaseService() (*DatabaseService, error) {
connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
utils.DatabaseConfig.Host, utils.DatabaseConfig.Port, utils.DatabaseConfig.User,
@ -33,30 +36,32 @@ func InitDatabaseService() (*DatabaseService, error) {
}
conn.SetMaxOpenConns(utils.DatabaseConfig.MaxOpenConns)
conn.SetMaxIdleConns(utils.DatabaseConfig.MaxIdleConn)
conn.SetConnMaxLifetime(0) //always REUSE
conn.SetConnMaxLifetime(0) // always REUSE
dbService.Connection = conn
return dbService, nil
}
// DbNamedExec 执行带有命名参数的sql语句
func DbNamedExec(query string, args interface{}) error {
_, err := dbService.Connection.NamedExec(query, args)
return err
}
// DbExecTx 执行事务
func DbExecTx(query ...string) error {
tx := dbService.Connection.MustBegin()
for _, s := range query {
tx.MustExec(s)
} //end of for
} // end of for
err := tx.Commit()
if err != nil {
return tx.Rollback()
}
return nil
} //end of func
} // end of func
//
//func DbExecTx(query string, args ...interface{}) error {
// func DbExecTx(query string, args ...interface{}) error {
// tx, err := dbService.Connection.Begin()
// if err != nil {
// return err
@ -75,8 +80,9 @@ func DbExecTx(query ...string) error {
// }
//
// return nil
//}
// }
// DbGet 获取单条记录
func DbGet(query string, dest interface{}, args ...interface{}) error {
err := dbService.Connection.Get(dest, query, args...)
if err == sql.ErrNoRows {
@ -85,10 +91,12 @@ func DbGet(query string, dest interface{}, args ...interface{}) error {
return err
}
// DbSelect 获取多条记录
func DbSelect(query string, dest interface{}, args ...interface{}) error {
return dbService.Connection.Select(dest, query, args...)
}
// DbClose 关闭数据库连接
func DbClose() {
dbService.Connection.Close()
}

View File

@ -10,9 +10,10 @@ package storage
import (
"context"
"ohurlshortener/utils"
"time"
"ohurlshortener/utils"
"github.com/go-redis/redis/v8"
)
@ -21,10 +22,12 @@ var (
ctx = context.Background()
)
// RedisService Redis 服务
type RedisService struct {
redisClient *redis.Client
}
// InitRedisService 初始化 Redis 服务
func InitRedisService() (*RedisService, error) {
redisClient := redis.NewClient(&redis.Options{
Addr: utils.RedisConfig.Host,
@ -41,20 +44,24 @@ func InitRedisService() (*RedisService, error) {
return redisService, nil
}
// RedisSet 设置 Redis 键值对
func RedisSet(key string, value interface{}, ttl time.Duration) error {
return redisService.redisClient.Set(ctx, key, value, ttl).Err()
}
// RedisSet30m 设置 Redis 键值对,过期时间为 30 分钟
func RedisSet30m(key string, value interface{}) error {
return RedisSet(key, value, 30*time.Minute)
}
// RedisSet4Ever 设置 Redis 键值对,永不过期
func RedisSet4Ever(key string, value interface{}) error {
return RedisSet(key, value, redis.KeepTTL)
}
// RedisScan4Keys 获取 Redis 中所有以 prefix 开头的键
func RedisScan4Keys(prefix string) ([]string, error) {
keys := []string{}
var keys []string
sc := redisService.redisClient.Scan(ctx, 0, prefix, 0).Iterator()
for sc.Next(ctx) {
keys = append(keys, sc.Val())
@ -62,6 +69,7 @@ func RedisScan4Keys(prefix string) ([]string, error) {
return keys, nil
}
// RedisGetString 获取 Redis 中的字符串
func RedisGetString(key string) (string, error) {
result, err := redisService.redisClient.Get(ctx, key).Result()
if err == redis.Nil {
@ -70,10 +78,12 @@ func RedisGetString(key string) (string, error) {
return result, err
}
// RedisFlushDB 清空 Redis 中的所有键值对
func RedisFlushDB() error {
return redisService.redisClient.FlushDB(ctx).Err()
}
// RedisDelete 删除 Redis 中的键值对
func RedisDelete(key ...string) error {
if len(key) > 0 {
return redisService.redisClient.Del(ctx, key...).Err()

View File

@ -21,9 +21,10 @@ package storage
import (
"fmt"
"ohurlshortener/utils"
"time"
"ohurlshortener/utils"
"github.com/dchest/captcha"
)

View File

@ -21,10 +21,11 @@ package storage
import (
"log"
"ohurlshortener/utils"
"testing"
"time"
"ohurlshortener/utils"
"github.com/dchest/captcha"
)

View File

@ -5,20 +5,26 @@ import (
"ohurlshortener/utils"
)
// GetUrlStats 获取短链接的访问量统计信息
func GetUrlStats(url string) (core.ShortUrlStats, error) {
found := core.ShortUrlStats{}
query := `select * from public.url_ip_count_stats WHERE short_url = $1`
query := `select * from public.stats_ip_sum WHERE short_url = $1`
err := DbGet(query, &found, url)
return found, err
}
// GetUrlCount 获取短链接总数
func GetUrlCount() (int, error) {
var result int
query := `SELECT count(l.id) FROM public.short_urls l`
var (
result int
query = `SELECT count(l.id) FROM public.short_urls l`
)
// query := `SELECT n_live_tup AS estimate_rows FROM pg_stat_all_tables WHERE relname = 'short_urls'`
return result, DbGet(query, &result)
}
// GetSumOfUrlStats 获取所有短链接的访问量统计信息
func GetSumOfUrlStats() (core.ShortUrlStats, error) {
query := `SELECT * FROM public.stats_sum`
result := core.ShortUrlStats{}
@ -50,19 +56,20 @@ func GetSumOfUrlStats() (core.ShortUrlStats, error) {
return result, nil
}
// GetTop25 获取访问量前 25 的短链接
func GetTop25() ([]core.Top25Url, error) {
query := `SELECT u.*,s.today_count AS today_count,s.d_today_count AS d_today_count FROM public.short_urls u , public.stats_top25 s WHERE u.short_url = s.short_url`
found := []core.Top25Url{}
return found, DbSelect(query, &found)
}
// FindPagedUrlIpCountStats 获取单个短链接的 IP 访问量统计信息
func FindPagedUrlIpCountStats(url string, page int, size int) ([]core.UrlIpCountStats, error) {
found := []core.UrlIpCountStats{}
offset := (page - 1) * size
query := `SELECT s.*,u.id,u.dest_url,u.created_at,u.is_valid,u.memo
FROM public.stats_ip_sum s , public.short_urls u WHERE u.short_url = s.short_url ORDER BY u.created_at DESC LIMIT $1 OFFSET $2`
query := `SELECT s.*,u.id,u.dest_url,u.created_at,u.is_valid,u.memo FROM public.stats_ip_sum s , public.short_urls u WHERE u.short_url = s.short_url ORDER BY u.created_at DESC LIMIT $1 OFFSET $2`
if !utils.EmptyString(url) {
query := `SELECT s.*,u.id,u.dest_url,u.created_at,u.is_valid,u.memo
query := `SELECT s.*,u.id,u.dest_url,u.created_at,u.is_valid,u.memo
FROM public.stats_ip_sum s , public.short_urls u WHERE u.short_url = s.short_url AND u.short_url = $1 ORDER BY u.created_at DESC LIMIT $2 OFFSET $3`
var foundUrl core.UrlIpCountStats
err := DbGet(query, &foundUrl, url, size, offset)
@ -74,33 +81,30 @@ func FindPagedUrlIpCountStats(url string, page int, size int) ([]core.UrlIpCount
return found, DbSelect(query, &found, size, offset)
}
//
// CallProcedureStatsIPSum
// Call scheduled procedures to calculate stats result.
//
// Suggested time interval to call this procedure : 30 ~ 60 minutes
//
func CallProcedureStatsIPSum() error {
query := `SELECT 1 AS r FROM p_stats_ip_sum()`
var r int
return DbGet(query, &r)
}
//
// CallProcedureStatsTop25
// Call scheduled procedures to calculate stats result.
//
// Suggested time interval to call this procedure 5 ~ 10 minutes
//
func CallProcedureStatsTop25() error {
query := `SELECT 2 AS r FROM p_stats_top25()`
var r int
return DbGet(query, &r)
}
//
// CallProcedureStatsSum
// Call scheduled procedures to calculate stats result.
//
// Suggested time interval to call this procedure : 5 ~ 10 minutes
//
func CallProcedureStatsSum() error {
query := `SELECT 3 AS r FROM p_stats_sum()`
var r int

View File

@ -10,28 +10,33 @@ package storage
import (
"fmt"
"ohurlshortener/core"
"ohurlshortener/utils"
)
var Max_Insert_Count = 1000
var MaxInsertCount = 1000
// UpdateShortUrl 更新短链接
func UpdateShortUrl(shortUrl core.ShortUrl) error {
query := `UPDATE public.short_urls SET short_url = :short_url, dest_url = :dest_url, is_valid = :is_valid, memo = :memo WHERE id = :id`
return DbNamedExec(query, shortUrl)
}
// DeleteShortUrl 删除短链接
func DeleteShortUrl(shortUrl core.ShortUrl) error {
query := `DELETE from public.short_urls WHERE short_url = :short_url`
return DbNamedExec(query, shortUrl)
}
// DeleteShortUrlWithAccessLogs 删除短链接以及其访问日志
func DeleteShortUrlWithAccessLogs(shortUrl core.ShortUrl) error {
query1 := fmt.Sprintf(`DELETE from public.short_urls WHERE short_url = '%s'`, shortUrl.ShortUrl)
query2 := fmt.Sprintf(`DELETE from public.access_logs WHERE short_url = '%s'`, shortUrl.ShortUrl)
return DbExecTx(query1, query2)
} //end of Transaction Action
} // end of Transaction Action
// FindShortUrl 根据短链接查找短链接信息
func FindShortUrl(url string) (core.ShortUrl, error) {
found := core.ShortUrl{}
query := `SELECT * FROM public.short_urls WHERE short_url = $1`
@ -39,6 +44,7 @@ func FindShortUrl(url string) (core.ShortUrl, error) {
return found, err
}
// FindAllShortUrls 查找所有短链接
func FindAllShortUrls() ([]core.ShortUrl, error) {
found := []core.ShortUrl{}
query := `SELECT * FROM public.short_urls ORDER BY created_at DESC`
@ -46,6 +52,7 @@ func FindAllShortUrls() ([]core.ShortUrl, error) {
return found, err
}
// FindPagedShortUrls 分页查找短链接
func FindPagedShortUrls(url string, page int, size int) ([]core.ShortUrl, error) {
found := []core.ShortUrl{}
offset := (page - 1) * size
@ -62,6 +69,7 @@ func FindPagedShortUrls(url string, page int, size int) ([]core.ShortUrl, error)
return found, DbSelect(query, &found, size, offset)
}
// InsertShortUrl 插入短链接
func InsertShortUrl(url core.ShortUrl) error {
query := `INSERT INTO public.short_urls (short_url, dest_url, created_at, is_valid, memo)
VALUES(:short_url,:dest_url,:created_at,:is_valid,:memo)`

View File

@ -2,10 +2,11 @@ package storage
import (
"database/sql"
"ohurlshortener/core"
"testing"
"time"
"ohurlshortener/core"
"github.com/bxcodec/faker/v3"
)

View File

@ -1,19 +1,22 @@
package storage
import (
"strings"
"ohurlshortener/core"
"ohurlshortener/utils"
"strings"
"github.com/btcsuite/btcd/btcutil/base58"
)
// FindAllUsers 获取所有用户
func FindAllUsers() ([]core.User, error) {
var found []core.User
query := `SELECT * FROM public.users u`
return found, DbSelect(query, &found)
}
// NewUser 新建用户
func NewUser(account string, password string) error {
query := `INSERT INTO public.users (account, "password") VALUES(:account,:password)`
data, err := PasswordBase58Hash(password)
@ -23,17 +26,20 @@ func NewUser(account string, password string) error {
return DbNamedExec(query, core.User{Account: account, Password: data})
}
// UpdateUser 更新用户
func UpdateUser(user core.User) error {
query := `UPDATE public.users SET account = :account , "password" = :password WHERE id = :id`
return DbNamedExec(query, user)
}
// FindUserByAccount 根据账号查找用户
func FindUserByAccount(account string) (core.User, error) {
var user core.User
query := `SELECT * FROM public.users u WHERE lower(u.account) = $1`
return user, DbGet(query, &user, strings.ToLower(account))
}
// PasswordBase58Hash 密码加密
func PasswordBase58Hash(password string) (string, error) {
data, err := utils.Sha256Of(password)
if err != nil {

View File

@ -1,8 +1,9 @@
package storage
import (
"ohurlshortener/utils"
"testing"
"ohurlshortener/utils"
)
func TestNewUser(t *testing.T) {

View File

@ -20,6 +20,7 @@ var (
RedisConfig RedisConfigInfo
)
// AppConfigInfo 应用配置
type AppConfigInfo struct {
Port int
AdminPort int
@ -27,6 +28,7 @@ type AppConfigInfo struct {
Debug bool
}
// RedisConfigInfo redis配置
type RedisConfigInfo struct {
Host string
User string
@ -35,6 +37,7 @@ type RedisConfigInfo struct {
PoolSize int
}
// DatabaseConfigInfo 数据库配置
type DatabaseConfigInfo struct {
Host string
Port int
@ -45,8 +48,8 @@ type DatabaseConfigInfo struct {
MaxIdleConn int
}
// InitConfig 初始化配置
func InitConfig(file string) (*ini.File, error) {
cfg, err := ini.Load(file)
if err != nil {
return nil, nil

View File

@ -10,9 +10,10 @@ package export
import (
"errors"
"ohurlshortener/core"
"strconv"
"ohurlshortener/core"
"github.com/xuri/excelize/v2"
)
@ -22,7 +23,7 @@ func AccessLogToExcel(logs []core.AccessLog) ([]byte, error) {
}
f := excelize.NewFile()
index := f.NewSheet("Sheet1")
//填充表头
// 填充表头
f.SetCellValue("Sheet1", "A1", "短链接")
f.SetCellValue("Sheet1", "B1", "访问时间")
f.SetCellValue("Sheet1", "C1", "访问IP")

View File

@ -20,6 +20,7 @@ import (
"github.com/btcsuite/btcd/btcutil/base58"
)
// ExitOnError 退出程序
func ExitOnError(message string, err error) {
if err != nil {
log.Printf("[%s] - %s", message, err)
@ -27,12 +28,14 @@ func ExitOnError(message string, err error) {
}
}
// PrintOnError 打印错误
func PrintOnError(message string, err error) {
if err != nil {
log.Printf("[%s] - %s", message, err)
}
}
// RaiseError 返回错误
func RaiseError(message string) error {
if !EmptyString(message) {
return fmt.Errorf(message)
@ -40,11 +43,13 @@ func RaiseError(message string) error {
return nil
}
// EmptyString 判断字符串是否为空
func EmptyString(str string) bool {
str = strings.TrimSpace(str)
return strings.EqualFold(str, "")
}
// UserAgentIpHash 生成用户代理和IP的哈希值
func UserAgentIpHash(useragent string, ip string) string {
input := fmt.Sprintf("%s-%s-%s-%d", useragent, ip, time.Now().String(), rand.Int())
data, _ := Sha256Of(input)
@ -52,6 +57,7 @@ func UserAgentIpHash(useragent string, ip string) string {
return str[:10]
}
// Sha256Of 计算字符串的哈希值
func Sha256Of(input string) ([]byte, error) {
algorithm := sha256.New()
_, err := algorithm.Write([]byte(strings.TrimSpace(input)))
@ -61,6 +67,7 @@ func Sha256Of(input string) ([]byte, error) {
return algorithm.Sum(nil), nil
}
// Base58Encode base58编码
func Base58Encode(data []byte) string {
return base58.Encode(data)
}