Compare commits

...

2 Commits

Author SHA1 Message Date
469f633310 Verify permissions when joining lobby 2024-07-09 18:37:53 -05:00
008cda9abd ws impl. Authenticate via ws 2024-05-25 16:17:48 -05:00
5 changed files with 168 additions and 54 deletions

View File

@ -3,20 +3,26 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/nrednav/cuid2" "log"
"net/http" "net/http"
"github.com/nrednav/cuid2"
) )
// TODO: This struct should have a creation time // TODO: This struct should have a creation time
type Lobby struct { type Lobby struct {
LobbyOwner string LobbyOwner string
LobbyPlayers [3]string LobbyPlayers []string
} }
type LobbyResult struct { type LobbyResult struct {
LobbyId string LobbyId string
} }
type LobbyCreateInput struct {
UserId string
}
// TODO: We should remove entries from this map when they expire. // TODO: We should remove entries from this map when they expire.
// TODO: Define how long lobbies last // TODO: Define how long lobbies last
var lobbies = make(map[string]Lobby) var lobbies = make(map[string]Lobby)
@ -30,9 +36,25 @@ func CreateLobby(writer http.ResponseWriter, request *http.Request) {
return return
} }
// Get the UserId from the JSON payload
decoder := json.NewDecoder(request.Body)
var data LobbyCreateInput
err := decoder.Decode(&data)
if err != nil {
fmt.Printf("Error in JSON decoding: %s\n", err)
writer.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(writer, "{\"error\": \"%s\"}", err)
return
}
lobbyId := cuid2.Generate() lobbyId := cuid2.Generate()
result := LobbyResult{LobbyId: lobbyId} result := LobbyResult{LobbyId: lobbyId}
lobbies[lobbyId] = Lobby{
LobbyOwner: data.UserId,
LobbyPlayers: make([]string, 3),
}
log.Printf("Created lobby with id %s/%s", lobbyId, data.UserId)
jsonData, err := json.Marshal(result) jsonData, err := json.Marshal(result)
if err != nil { if err != nil {
@ -44,3 +66,23 @@ func CreateLobby(writer http.ResponseWriter, request *http.Request) {
fmt.Fprintf(writer, "%s", jsonData) fmt.Fprintf(writer, "%s", jsonData)
} }
// Verifies that the userId has access to the lobbyId
func VerifyLobbyAccess(userId, lobbyId string) bool {
lobby, ok := lobbies[lobbyId]
if !ok {
return false
}
if lobby.LobbyOwner == userId {
return true
}
for _, playerId := range lobby.LobbyPlayers {
if playerId == userId {
return true
}
}
return false
}

View File

@ -1,9 +1,11 @@
package controller package controller
import ( import (
"encoding/json"
"errors"
"log" "log"
"net/http" "net/http"
"time" "strings"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -14,10 +16,15 @@ var upgrader = websocket.Upgrader{
}, },
} }
type LobbyMsg struct {
Action string `json:"action"`
Value string `json:"value"`
}
func LobbyWsConnect(writer http.ResponseWriter, request *http.Request) { func LobbyWsConnect(writer http.ResponseWriter, request *http.Request) {
conn, err := upgrader.Upgrade(writer, request, nil) conn, err := upgrader.Upgrade(writer, request, nil)
if err != nil { if err != nil {
log.Print("upgrade:", err) log.Print("upgrade error:", err)
return return
} }
defer conn.Close() defer conn.Close()
@ -25,18 +32,81 @@ func LobbyWsConnect(writer http.ResponseWriter, request *http.Request) {
for { for {
mt, message, err := conn.ReadMessage() mt, message, err := conn.ReadMessage()
if err != nil { if err != nil {
log.Print("read:", err) log.Print("read error:", err)
break break
} }
log.Printf("recv: %s, type: %d", message, mt) log.Printf("recv: %s, type: %d", message, mt)
time.Sleep(10 * time.Second) var data LobbyMsg
err = json.Unmarshal(message, &data)
err = conn.WriteMessage(mt, message)
if err != nil { if err != nil {
log.Print("write:", err) log.Print("json error:", err)
break
}
switch data.Action {
case "auth":
err = authenticateConnection(mt, conn, data.Value)
default:
log.Print("no action :c")
}
if err != nil {
log.Print("error:", err)
break break
} }
} }
} }
// Verifies that the user id & lobby id are valid, and that the user has permission to
// access the lobby
func authenticateConnection(mt int, conn *websocket.Conn, authInfo string) error {
// TODO: split userId by ','
var err error
var result string
authSections := strings.Split(authInfo, ",")
if len(authSections) != 2 {
err = errors.New("Expected 2 components to auth, in string " + authInfo)
result = "unauthenticated"
} else {
userId := authSections[0]
lobbyId := authSections[1]
if !VerifyLobbyAccess(userId, lobbyId) {
log.Printf("Unathorized to enter lobby: user %s to lobby %s", userId, lobbyId)
result = "unauthenticated"
} else {
_, ok := Users[userId]
// TODO: Verify lobby id
if ok {
result = "authenticated"
} else {
result = "unauthenticated"
}
}
}
if err != nil {
log.Print("auth error: ", err)
}
json, err := json.Marshal(LobbyMsg{
Action: "auth",
Value: result,
})
if err != nil {
log.Print("json marshal: ", err)
return err
}
err = conn.WriteMessage(mt, json)
if err != nil {
log.Print("write error: ", err)
}
return err
}

View File

@ -1,12 +1,21 @@
package controller package controller
import ( import (
"github.com/nrednav/cuid2" "encoding/json"
"fmt"
"net/http" "net/http"
"net/url"
"github.com/nrednav/cuid2"
) )
var Users map[string]string = make(map[string]string) var Users map[string]string = make(map[string]string)
type PersonInfo struct {
UserId string
Username string
}
func Register(username string) string { func Register(username string) string {
uid := cuid2.Generate() uid := cuid2.Generate()
@ -16,6 +25,41 @@ func Register(username string) string {
return uid return uid
} }
func RegisterUser(writer http.ResponseWriter, request *http.Request) {
requestUrl := request.URL
params, err := url.ParseQuery(requestUrl.RawQuery)
if err != nil {
WriteError(err, "Error parsing URL parameters", &writer)
return
}
usernameArr, ok := params["username"]
if !ok {
WriteError(err, "username not found", &writer)
return
}
username := usernameArr[0]
// The result json
result := PersonInfo{
UserId: Register(username),
Username: username,
}
writer.Header().Set("Content-Type", "application/json")
jsonData, err := json.Marshal(result)
if err != nil {
WriteError(err, "Error serializing JSON", &writer)
return
}
writer.WriteHeader(http.StatusOK)
fmt.Fprintf(writer, "%s", jsonData)
}
func ValidateId(writer http.ResponseWriter, request *http.Request) { func ValidateId(writer http.ResponseWriter, request *http.Request) {
if AuthHeaderIsValid(request.Header.Get("Authorization")) { if AuthHeaderIsValid(request.Header.Get("Authorization")) {
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)

View File

@ -10,7 +10,6 @@ func WriteError(err error, message string, writer *http.ResponseWriter) {
fmt.Printf("Error: %s\n", err) fmt.Printf("Error: %s\n", err)
(*writer).WriteHeader(http.StatusInternalServerError) (*writer).WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(*writer, "{\"error\": \"%s\"}", message) fmt.Fprintf(*writer, "{\"error\": \"%s\"}", message)
return
} }
func AuthHeaderIsValid(authHeader string) bool { func AuthHeaderIsValid(authHeader string) bool {
@ -23,7 +22,7 @@ func AuthHeaderIsValid(authHeader string) bool {
bearerToken := reqToken[7:] bearerToken := reqToken[7:]
// Check that the token is in the global map // Check that the token is in the global map
_, ok := (Users)[bearerToken] _, ok := Users[bearerToken]
return ok return ok
} }

43
main.go
View File

@ -2,22 +2,15 @@ package main
import ( import (
"card-jong-be/controller" "card-jong-be/controller"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"net/url"
"os" "os"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/cors" "github.com/rs/cors"
) )
type PersonInfo struct {
UserId string
Username string
}
func main() { func main() {
fmt.Println("hello SEKAI!!") fmt.Println("hello SEKAI!!")
mainRouter := mux.NewRouter() mainRouter := mux.NewRouter()
@ -25,7 +18,7 @@ func main() {
wsRouter := mainRouter.PathPrefix("/ws").Subrouter() wsRouter := mainRouter.PathPrefix("/ws").Subrouter()
// HTTP routes // HTTP routes
httpRouter.HandleFunc("/register", Register) httpRouter.HandleFunc("/register", controller.RegisterUser)
httpRouter.HandleFunc("/validate", controller.ValidateId) httpRouter.HandleFunc("/validate", controller.ValidateId)
httpRouter.HandleFunc("/lobby/new", controller.CreateLobby).Methods("POST") httpRouter.HandleFunc("/lobby/new", controller.CreateLobby).Methods("POST")
@ -48,37 +41,3 @@ func main() {
log.Fatal(http.ListenAndServe(":"+port, handler)) log.Fatal(http.ListenAndServe(":"+port, handler))
} }
func Register(writer http.ResponseWriter, request *http.Request) {
requestUrl := request.URL
params, err := url.ParseQuery(requestUrl.RawQuery)
if err != nil {
controller.WriteError(err, "Error parsing URL parameters", &writer)
return
}
usernameArr, ok := params["username"]
if !ok {
controller.WriteError(err, "username not found", &writer)
return
}
username := usernameArr[0]
// The result json
result := PersonInfo{
UserId: controller.Register(username),
Username: username,
}
writer.Header().Set("Content-Type", "application/json")
jsonData, err := json.Marshal(result)
if err != nil {
controller.WriteError(err, "Error serializing JSON", &writer)
return
}
writer.WriteHeader(http.StatusOK)
fmt.Fprintf(writer, "%s", jsonData)
}