Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to stabilize OPC UA unit tests #1797

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ public class OpcuaConfiguration implements PlcConnectionConfiguration {
@Description("TCP encoding options")
private Limits limits;

@ConfigurationParameter("endpoint-host")
@Description("Endpoint host used to establish secure channel.")
private String endpointHost;

@ConfigurationParameter("endpoint-port")
@Description("Endpoint port used to establish secure channel")
private Integer endpointPort;

public String getProtocolCode() {
return protocolCode;
}
Expand Down Expand Up @@ -228,6 +236,14 @@ public long getNegotiationTimeout() {
return negotiationTimeout;
}

public String getEndpointHost() {
return endpointHost;
}

public Integer getEndpointPort() {
return endpointPort;
}

@Override
public String toString() {
return "OpcuaConfiguration{" +
Expand All @@ -240,5 +256,6 @@ public String toString() {
", limits=" + limits +
'}';
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ public String getHost() {
return host;
}

public void setHost(String host) {
this.host = host;
}

public String getPort() {
return port;
}
Expand All @@ -126,10 +122,6 @@ public String getEndpoint() {
public String getTransportEndpoint() {
return transportEndpoint;
}

public void setTransportEndpoint(String transportEndpoint) {
this.transportEndpoint = transportEndpoint;
}

public X509Certificate getServerCertificate() {
return serverCertificate;
Expand All @@ -147,6 +139,13 @@ public void setConfiguration(OpcuaConfiguration configuration) {
port = matcher.group("transportPort");
transportEndpoint = matcher.group("transportEndpoint");

if (configuration.getEndpointHost() != null) {
host = configuration.getEndpointHost();
}
if (configuration.getEndpointPort() != null) {
port = String.valueOf(configuration.getEndpointPort());
}

String portAddition = port != null ? ":" + port : "";
endpoint = "opc." + code + "://" + host + portAddition + transportEndpoint;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,24 @@
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.plc4x.java.api.authentication.PlcAuthentication;
import org.apache.plc4x.java.api.authentication.PlcUsernamePasswordAuthentication;
import org.apache.plc4x.java.api.exceptions.PlcRuntimeException;
import org.apache.plc4x.java.opcua.config.OpcuaConfiguration;
import org.apache.plc4x.java.opcua.readwrite.*;
import org.apache.plc4x.java.opcua.security.MessageSecurity;
import org.apache.plc4x.java.opcua.security.SecurityPolicy;
import org.apache.plc4x.java.opcua.security.SecurityPolicy.SignatureAlgorithm;
import org.apache.plc4x.java.spi.generation.*;
Expand All @@ -56,11 +61,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.util.concurrent.Executors.newSingleThreadExecutor;

public class SecureChannel {

private static final Logger LOGGER = LoggerFactory.getLogger(SecureChannel.class);
Expand Down Expand Up @@ -91,7 +93,7 @@ public class SecureChannel {
private final OpcuaDriverContext driverContext;
private final Conversation conversation;
private ScheduledFuture<?> keepAlive;
private final List<String> endpoints = new ArrayList<>();
private final Set<String> endpoints = new HashSet<>();
private double sessionTimeout;
private long revisedLifetime;

Expand All @@ -117,9 +119,9 @@ public SecureChannel(Conversation conversation, RequestTransactionManager tm, Op
// Generate a list of endpoints we can use.
try {
InetAddress address = InetAddress.getByName(driverContext.getHost());
this.endpoints.add(address.getHostAddress());
this.endpoints.add(address.getHostName());
this.endpoints.add(address.getCanonicalHostName());
this.endpoints.add("opc.tcp://" + address.getHostAddress() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
this.endpoints.add("opc.tcp://" + address.getHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
this.endpoints.add("opc.tcp://" + address.getCanonicalHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
} catch (UnknownHostException e) {
LOGGER.warn("Unable to resolve host name. Using original host from connection string which may cause issues connecting to server");
this.endpoints.add(driverContext.getHost());
Expand Down Expand Up @@ -313,23 +315,24 @@ private CompletableFuture<ActivateSessionResponse> onConnectActivateSessionReque
conversation.setRemoteCertificate(getX509Certificate(sessionResponse.getServerCertificate().getStringValue()));
conversation.setRemoteNonce(sessionResponse.getServerNonce().getStringValue());

String[] endpoints = new String[3];
List<String> contactPoints = new ArrayList<>(3);
try {
InetAddress address = InetAddress.getByName(driverContext.getHost());
endpoints[0] = "opc.tcp://" + address.getHostAddress() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint();
endpoints[1] = "opc.tcp://" + address.getHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint();
endpoints[2] = "opc.tcp://" + address.getCanonicalHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint();
contactPoints.add("opc.tcp://" + address.getHostAddress() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
contactPoints.add("opc.tcp://" + address.getHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
contactPoints.add("opc.tcp://" + address.getCanonicalHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
} catch (UnknownHostException e) {
LOGGER.debug("error getting host", e);
}

Entry<EndpointDescription, UserTokenPolicy> endpointAndAuthPolicy = selectEndpoint(sessionResponse);
if (endpointAndAuthPolicy == null) {
throw new PlcRuntimeException("Unable to find endpoint - " + endpoints[1]);
Entry<EndpointDescription, UserTokenPolicy> selectedEndpoint = selectEndpoint(sessionResponse.getServerEndpoints(), contactPoints,
configuration.getSecurityPolicy(), configuration.getMessageSecurity());
if (selectedEndpoint == null) {
throw new PlcRuntimeException("Unable to find endpoint matching - " + contactPoints.get(1));
}

PascalString policyId = endpointAndAuthPolicy.getValue().getPolicyId();
UserTokenType tokenType = endpointAndAuthPolicy.getValue().getTokenType();
PascalString policyId = selectedEndpoint.getValue().getPolicyId();
UserTokenType tokenType = selectedEndpoint.getValue().getTokenType();
ExtensionObject userIdentityToken = getIdentityToken(tokenType, policyId.getStringValue());
RequestHeader requestHeader = conversation.createRequestHeader();
SignatureData clientSignature = new SignatureData(NULL_STRING, NULL_BYTE_STRING);
Expand Down Expand Up @@ -421,27 +424,19 @@ public CompletableFuture<EndpointDescription> onDiscoverGetEndpointsRequest() {

return conversation.submit(endpointsRequest, GetEndpointsResponse.class).thenApply(response -> {
List<ExtensionObjectDefinition> endpoints = response.getEndpoints();
MessageSecurityMode effectiveMode = this.configuration.getSecurityPolicy() == SecurityPolicy.NONE ? MessageSecurityMode.messageSecurityModeNone : this.configuration.getMessageSecurity().getMode();
for (ExtensionObjectDefinition endpoint : endpoints) {
EndpointDescription endpointDescription = (EndpointDescription) endpoint;

boolean urlMatch = endpointDescription.getEndpointUrl().getStringValue().equals(this.endpoint.getStringValue());
boolean policyMatch = endpointDescription.getSecurityPolicyUri().getStringValue().equals(this.configuration.getSecurityPolicy().getSecurityPolicyUri());
boolean msgSecurityMatch = endpointDescription.getSecurityMode().equals(effectiveMode);

LOGGER.debug("Validate OPC UA endpoint {} during discovery phase."
+ "Expected {}. Endpoint policy {} looking for {}. Message security {}, looking for {}", endpointDescription.getEndpointUrl().getStringValue(), this.endpoint.getStringValue(),
endpointDescription.getSecurityPolicyUri().getStringValue(), configuration.getSecurityPolicy().getSecurityPolicyUri(),
endpointDescription.getSecurityMode(), configuration.getMessageSecurity().getMode());

if (urlMatch && policyMatch && msgSecurityMatch) {
LOGGER.info("Found OPC UA endpoint {}", this.endpoint.getStringValue());
return endpointDescription;
}
Entry<EndpointDescription, UserTokenPolicy> entry = selectEndpoint(response.getEndpoints(), this.endpoints, this.configuration.getSecurityPolicy(), this.configuration.getMessageSecurity());

if (entry == null) {
Set<String> endpointUris = endpoints.stream()
.filter(EndpointDescription.class::isInstance)
.map(EndpointDescription.class::cast)
.map(EndpointDescription::getEndpointUrl)
.map(PascalString::getStringValue)
.collect(Collectors.toSet());
throw new IllegalArgumentException("Could not find endpoint matching client configuration. Tested " + endpointUris + ". "
+ "Was looking for " + this.endpoint.getStringValue() + " " + this.configuration.getSecurityPolicy().getSecurityPolicyUri() + " " + this.configuration.getMessageSecurity().getMode());
}

throw new IllegalArgumentException("Could not find endpoint matching client configuration. Tested " + endpoints.size() + " endpoints. "
+ "None matched " + this.endpoint.getStringValue() + " " + this.configuration.getSecurityPolicy().getSecurityPolicyUri() + " " + this.configuration.getMessageSecurity().getMode());
return entry.getKey();
});
}

Expand Down Expand Up @@ -503,32 +498,49 @@ private static ReadBufferByteBased toBuffer(Supplier<Payload> supplier) {
/**
* Selects the endpoint and authentication policy based on client settings.
*
* @param sessionResponse - The CreateSessionResponse message returned by the server
* @return Entry representing desired server endpoint and user token policy to access it.
* @param extensionObjects Endpoint descriptions returned by the server.
* @param contactPoints Contact points expected by client.
* @param securityPolicy Security policy searched in endpoints.
* @param messageSecurity Message security needed by client.
* @return Endpoint matching given.
*/
private Entry<EndpointDescription, UserTokenPolicy> selectEndpoint(CreateSessionResponse sessionResponse) {
private Entry<EndpointDescription, UserTokenPolicy> selectEndpoint(List<ExtensionObjectDefinition> extensionObjects, Collection<String> contactPoints,
SecurityPolicy securityPolicy, MessageSecurity messageSecurity) throws PlcRuntimeException {
// Get a list of the endpoints which match ours.
EndpointDescription selectedEndpoint = null;
for (ExtensionObjectDefinition endpoint : sessionResponse.getServerEndpoints()) {
if (!(endpoint instanceof EndpointDescription)) {
MessageSecurityMode effectiveMessageSecurity = SecurityPolicy.NONE == securityPolicy ? MessageSecurityMode.messageSecurityModeNone : messageSecurity.getMode();
List<Entry<EndpointDescription, UserTokenPolicy>> serverEndpoints = new ArrayList<>();

for (ExtensionObjectDefinition extensionObject : extensionObjects) {
if (!(extensionObject instanceof EndpointDescription)) {
continue;
}
if (isEndpoint((EndpointDescription) endpoint)) {
selectedEndpoint = (EndpointDescription) endpoint;
break;

EndpointDescription endpointDescription = (EndpointDescription) extensionObject;
if (isMatchingEndpoint(endpointDescription, contactPoints)) {
boolean policyMatch = endpointDescription.getSecurityPolicyUri().getStringValue().equals(securityPolicy.getSecurityPolicyUri());
boolean msgSecurityMatch = endpointDescription.getSecurityMode().equals(effectiveMessageSecurity);

if (!policyMatch && !msgSecurityMatch) {
continue;
}

for (ExtensionObjectDefinition objectDefinition : endpointDescription.getUserIdentityTokens()) {
if (objectDefinition instanceof UserTokenPolicy) {
UserTokenPolicy userTokenPolicy = (UserTokenPolicy) objectDefinition;
if (isUserTokenPolicyCompatible(userTokenPolicy, this.username)) {
serverEndpoints.add(entry(endpointDescription, userTokenPolicy));
}
}
}
}
}

for (ExtensionObjectDefinition tokenPolicy : selectedEndpoint.getUserIdentityTokens()) {
if (!(tokenPolicy instanceof UserTokenPolicy)) {
continue;
}
if (hasIdentity((UserTokenPolicy) tokenPolicy)) {
return entry(selectedEndpoint, (UserTokenPolicy) tokenPolicy);
}
if (serverEndpoints.isEmpty()) {
return null;
}

return null;
serverEndpoints.sort(Comparator.comparing(e -> e.getKey().getSecurityLevel()));
return serverEndpoints.get(0);
}

/**
Expand All @@ -539,36 +551,14 @@ private Entry<EndpointDescription, UserTokenPolicy> selectEndpoint(CreateSession
* @return true if this endpoint matches our configuration
* @throws PlcRuntimeException - If the returned endpoint string doesn't match the format expected
*/
private boolean isEndpoint(EndpointDescription endpoint) throws PlcRuntimeException {
private static boolean isMatchingEndpoint(EndpointDescription endpoint, Collection<String> contactPoints) throws PlcRuntimeException {
// Split up the connection string into it's individual segments.
String endpointUri = endpoint.getEndpointUrl().getStringValue();
Matcher matcher = URI_PATTERN.matcher(endpointUri);
if (!matcher.matches()) {
throw new PlcRuntimeException(
"Endpoint " + endpointUri + " returned from the server doesn't match the format '{protocol-code}:({transport-code})?//{transport-host}(:{transport-port})(/{transport-endpoint})'");
}
LOGGER.trace("Using Endpoint {} {} {}", matcher.group("transportHost"), matcher.group("transportPort"), matcher.group("transportEndpoint"));

//When the parameter discovery=false is configured, prefer using the custom address. If the transportEndpoint is empty,
// directly replace it with the TransportEndpoint returned by the server.
if (!configuration.isDiscovery() && StringUtils.isBlank(driverContext.getTransportEndpoint())) {
driverContext.setTransportEndpoint(matcher.group("transportEndpoint"));
return true;
}

if (configuration.isDiscovery() && !this.endpoints.contains(matcher.group("transportHost"))) {
return false;
}

if (!driverContext.getPort().equals(matcher.group("transportPort"))) {
return false;
}

if (!driverContext.getTransportEndpoint().equals(matcher.group("transportEndpoint"))) {
return false;
for (String contactPoint : contactPoints) {
if (endpoint.getEndpointUrl().getStringValue().startsWith(contactPoint)) {
return true;
}
}

return true;
return false;
}

/**
Expand All @@ -577,11 +567,11 @@ private boolean isEndpoint(EndpointDescription endpoint) throws PlcRuntimeExcept
* @param policy - UserTokenPolicy configured for server endpoint.
* @return True if given token policy matches client configuration.
*/
private boolean hasIdentity(UserTokenPolicy policy) {
if ((policy.getTokenType() == UserTokenType.userTokenTypeAnonymous) && this.username == null) {
private static boolean isUserTokenPolicyCompatible(UserTokenPolicy policy, String username) {
if ((policy.getTokenType() == UserTokenType.userTokenTypeAnonymous) && username == null) {
return true;
}
return policy.getTokenType() == UserTokenType.userTokenTypeUserName && this.username != null;
return policy.getTokenType() == UserTokenType.userTokenTypeUserName && username != null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

public class ChunkFactory {

public static int SYMMETRIC_SECURITY_HEADER_SIZE = 4;
public static final int ASYMMETRIC_SECURITY_HEADER_SIZE = 59;
public static final int SYMMETRIC_SECURITY_HEADER_SIZE = 4;

public Chunk create(boolean asymmetric, Conversation conversation) {
return create(asymmetric,
Expand All @@ -48,7 +49,7 @@ public Chunk create(boolean asymmetric, boolean encrypted, boolean signed, Secur

if (securityPolicy == SecurityPolicy.NONE) {
return new Chunk(
asymmetric ? 59 : SYMMETRIC_SECURITY_HEADER_SIZE,
asymmetric ? ASYMMETRIC_SECURITY_HEADER_SIZE : SYMMETRIC_SECURITY_HEADER_SIZE,
1,
1,
securityPolicy.getSymmetricSignatureSize(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ public class MiloTestContainer extends GenericContainer<MiloTestContainer> {

private final static Logger logger = LoggerFactory.getLogger(MiloTestContainer.class);

private final static ImageFromDockerfile IMAGE = inlineImage();

public MiloTestContainer() {
super(inlineImage());
super(IMAGE);

waitingFor(Wait.forLogMessage("Server started\\s*", 1));
addFixedExposedPort(12686, 12686);
addExposedPort(12686);
}

private static ImageFromDockerfile inlineImage() {
Expand Down
Loading
Loading