Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions internals/proxy/middlewares/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/codeshelldev/gotl/pkg/logger"
log "github.com/codeshelldev/gotl/pkg/logger"
"github.com/codeshelldev/gotl/pkg/request"
"github.com/codeshelldev/secured-signal-api/internals/config"
)

Expand All @@ -22,12 +23,12 @@ var Auth Middleware = Middleware{

type AuthMethod struct {
Name string
Authenticate func(req *http.Request, tokens []string) (bool, error)
Authenticate func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error)
}

var BearerAuth = AuthMethod {
Name: "Bearer",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) {
header := req.Header.Get("Authorization")

headerParts := strings.SplitN(header, " ", 2)
Expand All @@ -50,7 +51,7 @@ var BearerAuth = AuthMethod {

var BasicAuth = AuthMethod {
Name: "Basic",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) {
header := req.Header.Get("Authorization")

if strings.TrimSpace(header) == "" {
Expand Down Expand Up @@ -90,9 +91,50 @@ var BasicAuth = AuthMethod {
},
}

var BodyAuth = AuthMethod {
Name: "Body",
Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) {
const authField = "auth"

body, err := request.GetReqBody(req)

if err != nil {
return false, nil
}

body.Write(req)

if body.Empty {
return false, nil
}

value, exists := body.Data[authField]

if !exists {
return false, nil
}

auth, ok := value.(string)

if !ok {
return false, nil
}

if isValidToken(tokens, auth) {
delete(body.Data, authField)

body.Write(req)

return true, nil
}

return false, errors.New("invalid Body token")
},
}

var QueryAuth = AuthMethod {
Name: "Query",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) {
const authQuery = "@authorization"

auth := req.URL.Query().Get(authQuery)
Expand All @@ -117,7 +159,7 @@ var QueryAuth = AuthMethod {

var PathAuth = AuthMethod {
Name: "Path",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) {
parts := strings.Split(req.URL.Path, "/")

if len(parts) == 0 {
Expand Down Expand Up @@ -155,6 +197,7 @@ func authHandler(next http.Handler) http.Handler {
var authChain = NewAuthChain().
Use(BearerAuth).
Use(BasicAuth).
Use(BodyAuth).
Use(QueryAuth).
Use(PathAuth)

Expand All @@ -166,11 +209,11 @@ func authHandler(next http.Handler) http.Handler {

var authToken string

success, _ := authChain.Eval(req, tokens)
success, _ := authChain.Eval(w, req, tokens)

if !success {
logger.Warn("User failed to provide auth")
w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"")

http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
Expand Down Expand Up @@ -202,12 +245,12 @@ func (chain *AuthChain) Use(method AuthMethod) *AuthChain {
return chain
}

func (chain *AuthChain) Eval(req *http.Request, tokens []string) (bool, error) {
func (chain *AuthChain) Eval(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) {
var err error
var success bool

for _, method := range chain.methods {
success, err = method.Authenticate(req, tokens)
success, err = method.Authenticate(w, req, tokens)

if err != nil {
logger.Warn("User failed ", method.Name, " auth: ", err.Error())
Expand Down