mirror of
https://gitea.phreedom.club/localhost_frssoft/bloat.git
synced 2025-05-05 11:28:45 +00:00
Add CSRF protection
This commit is contained in:
parent
5fdc7a59b2
commit
bf2cfaf0ed
13 changed files with 219 additions and 48 deletions
|
@ -11,7 +11,8 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrInvalidSession = errors.New("invalid session")
|
||||
ErrInvalidSession = errors.New("invalid session")
|
||||
ErrInvalidCSRFToken = errors.New("invalid csrf token")
|
||||
)
|
||||
|
||||
type authService struct {
|
||||
|
@ -47,6 +48,14 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error
|
|||
return c, nil
|
||||
}
|
||||
|
||||
func checkCSRF(ctx context.Context, c *model.Client) (err error) {
|
||||
csrfToken, ok := ctx.Value("csrf_token").(string)
|
||||
if !ok || csrfToken != c.Session.CSRFToken {
|
||||
return ErrInvalidCSRFToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authService) GetAuthUrl(ctx context.Context, instance string) (
|
||||
redirectUrl string, sessionID string, err error) {
|
||||
return s.Service.GetAuthUrl(ctx, instance)
|
||||
|
@ -184,6 +193,10 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.SaveSettings(ctx, client, c, settings)
|
||||
}
|
||||
|
||||
|
@ -192,6 +205,10 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.Like(ctx, client, c, id)
|
||||
}
|
||||
|
||||
|
@ -200,6 +217,10 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.UnLike(ctx, client, c, id)
|
||||
}
|
||||
|
||||
|
@ -208,6 +229,10 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.Retweet(ctx, client, c, id)
|
||||
}
|
||||
|
||||
|
@ -216,6 +241,10 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.UnRetweet(ctx, client, c, id)
|
||||
}
|
||||
|
||||
|
@ -224,6 +253,10 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files)
|
||||
}
|
||||
|
||||
|
@ -232,6 +265,10 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.Follow(ctx, client, c, id)
|
||||
}
|
||||
|
||||
|
@ -240,5 +277,9 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = checkCSRF(ctx, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.Service.UnFollow(ctx, client, c, id)
|
||||
}
|
||||
|
|
|
@ -78,12 +78,21 @@ func NewService(clientName string, clientScope string, clientWebsite string,
|
|||
}
|
||||
}
|
||||
|
||||
func getRendererContext(s model.Settings) *renderer.Context {
|
||||
func getRendererContext(c *model.Client) *renderer.Context {
|
||||
var settings model.Settings
|
||||
var session model.Session
|
||||
if c != nil {
|
||||
settings = c.Session.Settings
|
||||
session = c.Session
|
||||
} else {
|
||||
settings = *model.NewSettings()
|
||||
}
|
||||
return &renderer.Context{
|
||||
MaskNSFW: s.MaskNSFW,
|
||||
ThreadInNewTab: s.ThreadInNewTab,
|
||||
FluorideMode: s.FluorideMode,
|
||||
DarkMode: s.DarkMode,
|
||||
MaskNSFW: settings.MaskNSFW,
|
||||
ThreadInNewTab: settings.ThreadInNewTab,
|
||||
FluorideMode: settings.FluorideMode,
|
||||
DarkMode: settings.DarkMode,
|
||||
CSRFToken: session.CSRFToken,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -98,9 +107,11 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
|
|||
}
|
||||
|
||||
sessionID = util.NewSessionId()
|
||||
csrfToken := util.NewCSRFToken()
|
||||
session := model.Session{
|
||||
ID: sessionID,
|
||||
InstanceDomain: instance,
|
||||
CSRFToken: csrfToken,
|
||||
Settings: *model.NewSettings(),
|
||||
}
|
||||
err = svc.sessionRepo.Add(session)
|
||||
|
@ -199,13 +210,6 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
/*
|
||||
err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx))
|
||||
*/
|
||||
|
||||
return res.AccessToken, nil
|
||||
}
|
||||
|
@ -226,13 +230,7 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod
|
|||
Error: errStr,
|
||||
}
|
||||
|
||||
var s model.Settings
|
||||
if c != nil {
|
||||
s = c.Session.Settings
|
||||
} else {
|
||||
s = *model.NewSettings()
|
||||
}
|
||||
rCtx := getRendererContext(s)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
svc.renderer.RenderErrorPage(rCtx, client, data)
|
||||
}
|
||||
|
@ -247,7 +245,7 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err
|
|||
CommonData: commonData,
|
||||
}
|
||||
|
||||
rCtx := getRendererContext(*model.NewSettings())
|
||||
rCtx := getRendererContext(nil)
|
||||
return svc.renderer.RenderSigninPage(rCtx, client, data)
|
||||
}
|
||||
|
||||
|
@ -334,7 +332,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,
|
|||
PostContext: postContext,
|
||||
CommonData: commonData,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderTimelinePage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -416,7 +414,7 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo
|
|||
ReplyMap: replyMap,
|
||||
CommonData: commonData,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderThreadPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -478,7 +476,7 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer,
|
|||
NextLink: nextLink,
|
||||
CommonData: commonData,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderNotificationPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -525,7 +523,7 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode
|
|||
NextLink: nextLink,
|
||||
CommonData: commonData,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderUserPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -544,7 +542,7 @@ func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *mod
|
|||
data := &renderer.AboutData{
|
||||
CommonData: commonData,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderAboutPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -569,7 +567,7 @@ func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *mod
|
|||
Emojis: emojis,
|
||||
CommonData: commonData,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderEmojiPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -594,7 +592,7 @@ func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *m
|
|||
CommonData: commonData,
|
||||
Users: likers,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderLikedByPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -619,7 +617,7 @@ func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer,
|
|||
CommonData: commonData,
|
||||
Users: retweeters,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderRetweetedByPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -660,7 +658,7 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c
|
|||
HasNext: hasNext,
|
||||
NextLink: nextLink,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderFollowingPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -701,7 +699,7 @@ func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c
|
|||
HasNext: hasNext,
|
||||
NextLink: nextLink,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderFollowersPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -750,7 +748,7 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo
|
|||
HasNext: hasNext,
|
||||
NextLink: nextLink,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderSearchPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -770,7 +768,7 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *
|
|||
CommonData: commonData,
|
||||
Settings: &c.Session.Settings,
|
||||
}
|
||||
rCtx := getRendererContext(c.Session.Settings)
|
||||
rCtx := getRendererContext(c)
|
||||
|
||||
err = svc.renderer.RenderSettingsPage(rCtx, client, data)
|
||||
if err != nil {
|
||||
|
@ -828,6 +826,7 @@ func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *mode
|
|||
}
|
||||
|
||||
data.HeaderData.NotificationCount = notificationCount
|
||||
data.HeaderData.CSRFToken = c.Session.CSRFToken
|
||||
}
|
||||
|
||||
return
|
||||
|
|
|
@ -160,6 +160,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
retweetedByID := req.FormValue("retweeted_by_id")
|
||||
|
||||
|
@ -179,6 +181,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
retweetedByID := req.FormValue("retweeted_by_id")
|
||||
|
||||
|
@ -198,6 +202,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
retweetedByID := req.FormValue("retweeted_by_id")
|
||||
|
||||
|
@ -217,6 +223,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
retweetedByID := req.FormValue("retweeted_by_id")
|
||||
|
||||
|
@ -236,6 +244,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
count, err := s.Like(ctx, w, nil, id)
|
||||
if err != nil {
|
||||
|
@ -252,6 +262,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
count, err := s.UnLike(ctx, w, nil, id)
|
||||
if err != nil {
|
||||
|
@ -268,6 +280,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
count, err := s.Retweet(ctx, w, nil, id)
|
||||
if err != nil {
|
||||
|
@ -284,6 +298,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
count, err := s.UnRetweet(ctx, w, nil, id)
|
||||
if err != nil {
|
||||
|
@ -299,14 +315,16 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
}).Methods(http.MethodPost)
|
||||
|
||||
r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
|
||||
err := req.ParseMultipartForm(4 << 20)
|
||||
if err != nil {
|
||||
s.ServeErrorPage(ctx, w, nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token",
|
||||
getMultipartFormValue(req.MultipartForm, "csrf_token"))
|
||||
|
||||
content := getMultipartFormValue(req.MultipartForm, "content")
|
||||
replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id")
|
||||
format := getMultipartFormValue(req.MultipartForm, "format")
|
||||
|
@ -358,6 +376,7 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
|
||||
|
@ -373,6 +392,7 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
id, _ := mux.Vars(req)["id"]
|
||||
|
||||
|
@ -442,6 +462,7 @@ func NewHandler(s Service, staticDir string) http.Handler {
|
|||
|
||||
r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := getContextWithSession(context.Background(), req)
|
||||
ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
|
||||
|
||||
visibility := req.FormValue("visibility")
|
||||
copyScope := req.FormValue("copy_scope") == "true"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue