Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RESTWS-844: Expire user sessions authenticated with old password. #486

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
*/
package org.openmrs.module.webservices.rest.web.v1_0.controller.openmrs1_8;

import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.openmrs.User;
import org.openmrs.api.APIAuthenticationException;
import org.openmrs.api.APIException;
Expand All @@ -27,39 +26,46 @@
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.*;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Controller
@RequestMapping(value = "/rest/" + RestConstants.VERSION_1 + "/password")
public class ChangePasswordController1_8 extends BaseRestController {

@Qualifier("userService")
@Autowired
private UserService userService;


private final Log log = LogFactory.getLog(getClass());

@RequestMapping(method = RequestMethod.POST)
@ResponseStatus(HttpStatus.OK)
public void changeOwnPassword(@RequestBody Map<String, String> body) {
public void changeOwnPassword(@RequestBody Map<String, String> body, HttpServletRequest servletRequest) {
String oldPassword = body.get("oldPassword");
String newPassword = body.get("newPassword");
if (!Context.isAuthenticated()) {
throw new APIAuthenticationException("Must be authenticated to change your own password");
}
try {
userService.changePassword(oldPassword, newPassword);
}
catch (APIException ex) {
SessionListener.invalidateOtherSessions(Context.getAuthenticatedUser().getUuid(), servletRequest.getSession());
} catch (APIException ex) {
// this happens if they give the wrong oldPassword
log.error("Change password failed", ex);
throw new ValidationException(ex.getMessage());
} catch (Exception e) {
log.error("Change password failed", e);
}
}

@RequestMapping(value = "/{userUuid}", method = RequestMethod.POST)
@ResponseStatus(HttpStatus.OK)
public void changeOthersPassword(@PathVariable("userUuid") String userUuid, @RequestBody Map<String, String> body) {
Expand All @@ -69,26 +75,71 @@ public void changeOthersPassword(@PathVariable("userUuid") String userUuid, @Req
User user;
try {
user = userService.getUserByUuid(userUuid);
}
finally {
} finally {
Context.removeProxyPrivilege(PrivilegeConstants.VIEW_USERS);
Context.removeProxyPrivilege("Get Users");
}

if (user == null || user.getUserId() == null) {
throw new NullPointerException();
} else {
userService.changePassword(user, newPassword);
SessionListener.invalidateAllSessions(user.getUuid());
}
}

// This probably belongs in the base class, but we don't want to test all the behaviors that would change
@ExceptionHandler(NullPointerException.class)
@ResponseBody
public SimpleObject handleNotFound(NullPointerException exception, HttpServletRequest request,
HttpServletResponse response) {
HttpServletResponse response) {
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
return RestUtil.wrapErrorResponse(exception, "User not found");
}


static class SessionListener {
private static final Log log = LogFactory.getLog(SessionListener.class);

private static final Map<String, List<HttpSession>> map = new HashMap<>();

public static void sessionCreated(String userUuid, HttpSession httpSession) {
if (!map.containsKey(userUuid))
map.put(userUuid, new ArrayList<>());

List<HttpSession> sessions = map.get(userUuid);
if (sessions.contains(httpSession))
return;

sessions.add(httpSession);
log.info(String.format("Added new session. Total sessions for user: %s = %d", userUuid, map.get(userUuid).size()));
}

public static void invalidateOtherSessions(String userUuid, HttpSession currentSession) {
log.info(String.format("Finding other sessions for the user: %s, for session: %s", userUuid, currentSession));
List<HttpSession> sessions = map.get(userUuid);
for (HttpSession session : sessions) {
if (!currentSession.getId().equals(session.getId())) {
session.invalidate();
}
}
ArrayList<HttpSession> httpSessions = new ArrayList<>();
httpSessions.add(currentSession);
map.put(userUuid, httpSessions);
log.info(String.format("Invalidated %d other sessions for the user with this session", sessions.size() - 1));
}

public static void invalidateAllSessions(String userUuid) {
log.info(String.format("Finding other sessions for the user: %s", userUuid));

List<HttpSession> sessions = map.get(userUuid);
if (sessions == null) {
log.info("No sessions found for this user");
return;
}

sessions.forEach(HttpSession::invalidate);
map.remove(userUuid);
log.info(String.format("Found %d sessions for the user: %s", sessions.size(), userUuid));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import java.util.Set;

import org.apache.commons.lang3.LocaleUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.openmrs.User;
import org.openmrs.api.APIException;
import org.openmrs.api.context.Context;
import org.openmrs.module.webservices.rest.SimpleObject;
import org.openmrs.module.webservices.rest.web.ConversionUtil;
import org.openmrs.module.webservices.rest.web.RestConstants;
import org.openmrs.module.webservices.rest.web.api.RestService;
import org.openmrs.module.webservices.rest.web.representation.CustomRepresentation;
import org.openmrs.module.webservices.rest.web.representation.Representation;
import org.openmrs.module.webservices.rest.web.v1_0.controller.BaseRestController;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -33,41 +37,47 @@
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.context.request.WebRequest;

import javax.servlet.http.HttpServletRequest;

/**
* Controller that lets a client check the status of their session, and log out. (Authenticating is
* handled through a filter, and may happen through this or any other resource.
*/
@Controller
@RequestMapping(value = "/rest/" + RestConstants.VERSION_1 + "/session")
public class SessionController1_8 extends BaseRestController {

@Autowired
RestService restService;


private final Log log = LogFactory.getLog(getClass());

/**
* Tells the user their sessionId, and whether or not they are authenticated.
*
*
* @param request
* @return
* @should return the session id if the user is authenticated
* @should return the session id if the user is not authenticated
*/
@RequestMapping(method = RequestMethod.GET)
@ResponseBody
public Object get(WebRequest request) {
public Object get(WebRequest request, HttpServletRequest httpServletRequest) {
boolean authenticated = Context.isAuthenticated();
SimpleObject session = new SimpleObject();
session.add("sessionId", request.getSessionId()).add("authenticated", authenticated);
if (authenticated) {
String repParam = request.getParameter(RestConstants.REQUEST_PROPERTY_FOR_REPRESENTATION);
Representation rep = (repParam != null) ? restService.getRepresentation(repParam) : Representation.DEFAULT;
session.add("user", ConversionUtil.convertToRepresentation(Context.getAuthenticatedUser(), rep));
session.add("locale", Context.getLocale());
User user = Context.getAuthenticatedUser();
session.add("locale", Context.getLocale());
session.add("allowedLocales", Context.getAdministrationService().getAllowedLocales());
session.add("user", ConversionUtil.convertToRepresentation(user, rep));
ChangePasswordController1_8.SessionListener.sessionCreated(user.getUuid(), httpServletRequest.getSession(false));
}
return session;
}

@RequestMapping(method = RequestMethod.POST)
@ResponseBody
@ResponseStatus(value = HttpStatus.OK)
Expand All @@ -87,10 +97,10 @@ public void post(@RequestBody Map<String, String> body) {
throw new APIException(" '" + localeStr + "' is not in the list of allowed locales.");
}
}

/**
* Logs the client out
*
*
* @should log the client out
*/
@RequestMapping(method = RequestMethod.DELETE)
Expand All @@ -99,5 +109,5 @@ public void post(@RequestBody Map<String, String> body) {
public void delete() {
Context.logout();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,24 @@
import org.springframework.web.context.request.WebRequest;

public class SessionController1_8Test extends BaseModuleWebContextSensitiveTest {

private String SESSION_ID = "test-session-id";

private SessionController1_8 controller;
private WebRequest request;

private ServletWebRequest request;

@Before
public void before() {
controller = new SessionController1_8();
MockHttpServletRequest hsr = new MockHttpServletRequest();
hsr.setSession(new MockHttpSession(new MockServletContext(), SESSION_ID));
request = new ServletWebRequest(hsr);

Context.getAdministrationService().saveGlobalProperty(
new GlobalProperty(OpenmrsConstants.GLOBAL_PROPERTY_LOCALE_ALLOWED_LIST, "en_GB, sp, fr"));
}

/**
* @see SessionController1_8#delete()
* @verifies log the client out
Expand All @@ -58,15 +58,15 @@ public void delete_shouldLogTheClientOut() throws Exception {
controller.delete();
Assert.assertFalse(Context.isAuthenticated());
}

/**
* @see SessionController1_8#get(WebRequest)
* @see SessionController1_8#get(WebRequest, javax.servlet.http.HttpServletRequest)
* @verifies return the session id if the user is authenticated
*/
@Test
public void get_shouldReturnTheSessionIdAndUserIfTheUserIsAuthenticated() throws Exception {
Assert.assertTrue(Context.isAuthenticated());
Object ret = controller.get(request);
Object ret = controller.get(request, request.getRequest());
Object userProp = PropertyUtils.getProperty(ret, "user");
Assert.assertEquals(SESSION_ID, PropertyUtils.getProperty(ret, "sessionId"));
Assert.assertEquals(true, PropertyUtils.getProperty(ret, "authenticated"));
Expand All @@ -75,29 +75,29 @@ public void get_shouldReturnTheSessionIdAndUserIfTheUserIsAuthenticated() throws
Assert.assertEquals(Context.getAuthenticatedUser().getPerson().getUuid(),
PropertyUtils.getProperty(personProp, "uuid"));
}

@Test
public void get_shouldReturnLocaleInfoIfTheUserIsAuthenticated() throws Exception {
Assert.assertTrue(Context.isAuthenticated());
Object ret = controller.get(request);
Object ret = controller.get(request, request.getRequest());
Assert.assertEquals(Context.getLocale(), PropertyUtils.getProperty(ret, "locale"));
Assert.assertArrayEquals(Context.getAdministrationService().getAllowedLocales().toArray(),
((List<Locale>) PropertyUtils.getProperty(ret, "allowedLocales")).toArray());
}

/**
* @see SessionController1_8#get(WebRequest)
* @see SessionController1_8#get(WebRequest, javax.servlet.http.HttpServletRequest)
* @verifies return the session id if the user is not authenticated
*/
@Test
public void get_shouldReturnTheSessionIdIfTheUserIsNotAuthenticated() throws Exception {
Context.logout();
Assert.assertFalse(Context.isAuthenticated());
Object ret = controller.get(request);
Object ret = controller.get(request, request.getRequest());
Assert.assertEquals(SESSION_ID, PropertyUtils.getProperty(ret, "sessionId"));
Assert.assertEquals(false, PropertyUtils.getProperty(ret, "authenticated"));
}

@Test
public void post_shouldSetTheUserLocale() throws Exception {
Locale newLocale = new Locale("sp");
Expand All @@ -106,14 +106,14 @@ public void post_shouldSetTheUserLocale() throws Exception {
controller.post(new ObjectMapper().readValue(content, HashMap.class));
Assert.assertEquals(newLocale, Context.getLocale());
}

@Test(expected = APIException.class)
public void post_shouldFailWhenSettingIllegalLocale() throws Exception {
String newLocale = "fOOb@r:";
String content = "{\"locale\":\"" + newLocale + "\"}";
controller.post(new ObjectMapper().readValue(content, HashMap.class));
}

@Test(expected = APIException.class)
public void post_shouldFailWhenSettingDisallowedLocale() throws Exception {
String newLocale = "km_KH";
Expand Down
Loading