SEC-1404: Refactored IP subnet matching into IpAddressMatcher class to allow it to be used outside expressions.

This commit is contained in:
Luke Taylor 2010-02-10 15:06:01 +00:00
parent 1ecd3e228b
commit 08c7155ab5
2 changed files with 86 additions and 56 deletions

View File

@ -1,15 +1,11 @@
package org.springframework.security.web.access.expression;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import javax.servlet.http.HttpServletRequest;
import org.springframework.security.access.expression.SecurityExpressionRoot;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.FilterInvocation;
import org.springframework.util.StringUtils;
import org.springframework.security.web.util.IpAddressMatcher;
/**
*
@ -34,57 +30,7 @@ public class WebSecurityExpressionRoot extends SecurityExpressionRoot {
* @return true if the IP address of the current request is in the required range.
*/
public boolean hasIpAddress(String ipAddress) {
int nMaskBits = 0;
if (ipAddress.indexOf('/') > 0) {
String[] addressAndMask = StringUtils.split(ipAddress, "/");
ipAddress = addressAndMask[0];
nMaskBits = Integer.parseInt(addressAndMask[1]);
}
InetAddress requiredAddress = parseAddress(ipAddress);
InetAddress remoteAddress = parseAddress(request.getRemoteAddr());
if (!requiredAddress.getClass().equals(remoteAddress.getClass())) {
throw new IllegalArgumentException("IP Address in expression must be the same type as " +
"version returned by request");
}
if (nMaskBits == 0) {
return remoteAddress.equals(requiredAddress);
}
byte[] remAddr = remoteAddress.getAddress();
byte[] reqAddr = requiredAddress.getAddress();
int oddBits = nMaskBits % 8;
int nMaskBytes = nMaskBits/8 + (oddBits == 0 ? 0 : 1);
byte[] mask = new byte[nMaskBytes];
Arrays.fill(mask, 0, oddBits == 0 ? mask.length : mask.length - 1, (byte)0xFF);
if (oddBits != 0) {
int finalByte = (1 << oddBits) - 1;
finalByte <<= 8-oddBits;
mask[mask.length - 1] = (byte) finalByte;
}
// System.out.println("Mask is " + new sun.misc.HexDumpEncoder().encode(mask));
for (int i=0; i < mask.length; i++) {
if ((remAddr[i] & mask[i]) != (reqAddr[i] & mask[i])) {
return false;
}
}
return true;
return (new IpAddressMatcher(ipAddress).matches(request));
}
private InetAddress parseAddress(String address) {
try {
return InetAddress.getByName(address);
} catch (UnknownHostException e) {
throw new IllegalArgumentException("Failed to parse address" + address, e);
}
}
}

View File

@ -0,0 +1,84 @@
package org.springframework.security.web.util;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import javax.servlet.http.HttpServletRequest;
import org.springframework.util.StringUtils;
/**
* Matches a request based on IP Address or subnet mask matching against the remote address.
*
* @author Luke Taylor
* @since 3.0.2
*/
public class IpAddressMatcher implements RequestMatcher {
private final int nMaskBits;
private final InetAddress requiredAddress;
/**
* Takes a specific IP address or a range specified using the
* IP/Netmask (e.g. 192.168.1.0/24 or 202.24.0.0/14).
*
* @param ipAddress the address or range of addresses from which the request must come.
*/
public IpAddressMatcher(String ipAddress) {
if (ipAddress.indexOf('/') > 0) {
String[] addressAndMask = StringUtils.split(ipAddress, "/");
ipAddress = addressAndMask[0];
nMaskBits = Integer.parseInt(addressAndMask[1]);
} else {
nMaskBits = 0;
}
requiredAddress = parseAddress(ipAddress);
}
public boolean matches(HttpServletRequest request) {
InetAddress remoteAddress = parseAddress(request.getRemoteAddr());
if (!requiredAddress.getClass().equals(remoteAddress.getClass())) {
throw new IllegalArgumentException("IP Address in expression must be the same type as " +
"version returned by request");
}
if (nMaskBits == 0) {
return remoteAddress.equals(requiredAddress);
}
byte[] remAddr = remoteAddress.getAddress();
byte[] reqAddr = requiredAddress.getAddress();
int oddBits = nMaskBits % 8;
int nMaskBytes = nMaskBits/8 + (oddBits == 0 ? 0 : 1);
byte[] mask = new byte[nMaskBytes];
Arrays.fill(mask, 0, oddBits == 0 ? mask.length : mask.length - 1, (byte)0xFF);
if (oddBits != 0) {
int finalByte = (1 << oddBits) - 1;
finalByte <<= 8-oddBits;
mask[mask.length - 1] = (byte) finalByte;
}
// System.out.println("Mask is " + new sun.misc.HexDumpEncoder().encode(mask));
for (int i=0; i < mask.length; i++) {
if ((remAddr[i] & mask[i]) != (reqAddr[i] & mask[i])) {
return false;
}
}
return true;
}
private InetAddress parseAddress(String address) {
try {
return InetAddress.getByName(address);
} catch (UnknownHostException e) {
throw new IllegalArgumentException("Failed to parse address" + address, e);
}
}
}