diff --git a/controller/lobby.go b/controller/lobby.go index 2b6680b..6bc189e 100644 --- a/controller/lobby.go +++ b/controller/lobby.go @@ -3,20 +3,26 @@ package controller import ( "encoding/json" "fmt" - "github.com/nrednav/cuid2" + "log" "net/http" + + "github.com/nrednav/cuid2" ) // TODO: This struct should have a creation time type Lobby struct { LobbyOwner string - LobbyPlayers [3]string + LobbyPlayers []string } type LobbyResult struct { LobbyId string } +type LobbyCreateInput struct { + UserId string +} + // TODO: We should remove entries from this map when they expire. // TODO: Define how long lobbies last var lobbies = make(map[string]Lobby) @@ -30,9 +36,25 @@ func CreateLobby(writer http.ResponseWriter, request *http.Request) { return } + // Get the UserId from the JSON payload + decoder := json.NewDecoder(request.Body) + var data LobbyCreateInput + err := decoder.Decode(&data) + if err != nil { + fmt.Printf("Error in JSON decoding: %s\n", err) + writer.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(writer, "{\"error\": \"%s\"}", err) + return + } + lobbyId := cuid2.Generate() result := LobbyResult{LobbyId: lobbyId} + lobbies[lobbyId] = Lobby{ + LobbyOwner: data.UserId, + LobbyPlayers: make([]string, 3), + } + log.Printf("Created lobby with id %s/%s", lobbyId, data.UserId) jsonData, err := json.Marshal(result) if err != nil { @@ -44,3 +66,23 @@ func CreateLobby(writer http.ResponseWriter, request *http.Request) { fmt.Fprintf(writer, "%s", jsonData) } + +// Verifies that the userId has access to the lobbyId +func VerifyLobbyAccess(userId, lobbyId string) bool { + lobby, ok := lobbies[lobbyId] + if !ok { + return false + } + + if lobby.LobbyOwner == userId { + return true + } + + for _, playerId := range lobby.LobbyPlayers { + if playerId == userId { + return true + } + } + + return false +} diff --git a/controller/lobbyws.go b/controller/lobbyws.go index ea5f050..6815ddf 100644 --- a/controller/lobbyws.go +++ b/controller/lobbyws.go @@ -2,8 +2,10 @@ package controller import ( "encoding/json" + "errors" "log" "net/http" + "strings" "github.com/gorilla/websocket" ) @@ -15,8 +17,8 @@ var upgrader = websocket.Upgrader{ } type LobbyMsg struct { - Action string `json:action` - Value string `json:value` + Action string `json:"action"` + Value string `json:"value"` } func LobbyWsConnect(writer http.ResponseWriter, request *http.Request) { @@ -45,7 +47,7 @@ func LobbyWsConnect(writer http.ResponseWriter, request *http.Request) { switch data.Action { case "auth": - err = validateUserId(mt, conn, data.Value) + err = authenticateConnection(mt, conn, data.Value) default: log.Print("no action :c") } @@ -57,18 +59,46 @@ func LobbyWsConnect(writer http.ResponseWriter, request *http.Request) { } } -func validateUserId(mt int, conn *websocket.Conn, userId string) error { - _, ok := Users[userId] +// Verifies that the user id & lobby id are valid, and that the user has permission to +// access the lobby +func authenticateConnection(mt int, conn *websocket.Conn, authInfo string) error { + // TODO: split userId by ',' - var responseJson LobbyMsg + var err error + var result string - if ok { - responseJson = LobbyMsg{Action: "auth", Value: "authorized"} + authSections := strings.Split(authInfo, ",") + if len(authSections) != 2 { + err = errors.New("Expected 2 components to auth, in string " + authInfo) + result = "unauthenticated" } else { - responseJson = LobbyMsg{Action: "auth", Value: "unauthorized"} + userId := authSections[0] + lobbyId := authSections[1] + + if !VerifyLobbyAccess(userId, lobbyId) { + log.Printf("Unathorized to enter lobby: user %s to lobby %s", userId, lobbyId) + result = "unauthenticated" + } else { + _, ok := Users[userId] + + // TODO: Verify lobby id + + if ok { + result = "authenticated" + } else { + result = "unauthenticated" + } + } } - json, err := json.Marshal(responseJson) + if err != nil { + log.Print("auth error: ", err) + } + + json, err := json.Marshal(LobbyMsg{ + Action: "auth", + Value: result, + }) if err != nil { log.Print("json marshal: ", err) return err