NIFI-10364: Simplified connection/session handling in SmbjClientService

This closes #6307.

Signed-off-by: Tamas Palfy <tpalfy@apache.org>
This commit is contained in:
Peter Turcsanyi 2022-08-17 13:08:07 +02:00 committed by Tamas Palfy
parent 86f01af60e
commit eaaff4ede9
4 changed files with 73 additions and 75 deletions

View File

@ -96,7 +96,7 @@ import org.apache.nifi.services.smb.SmbListableEntity;
+ "Share root directory. For example, for a given remote location" + "Share root directory. For example, for a given remote location"
+ "smb://HOSTNAME:PORT/SHARE/DIRECTORY, and a file is being listed from " + "smb://HOSTNAME:PORT/SHARE/DIRECTORY, and a file is being listed from "
+ "smb://HOSTNAME:PORT/SHARE/DIRECTORY/sub/folder/file then the path attribute will be set to " + "smb://HOSTNAME:PORT/SHARE/DIRECTORY/sub/folder/file then the path attribute will be set to "
+ "\"DIRECTORY/sub/folder/file\"."), + "\"DIRECTORY/sub/folder\"."),
@WritesAttribute(attribute = SERVICE_LOCATION, description = @WritesAttribute(attribute = SERVICE_LOCATION, description =
"The SMB URL of the share."), "The SMB URL of the share."),
@WritesAttribute(attribute = LAST_MODIFIED_TIME, description = @WritesAttribute(attribute = LAST_MODIFIED_TIME, description =

View File

@ -30,6 +30,11 @@ import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import com.hierynomus.smbj.connection.Connection;
import com.hierynomus.smbj.session.Session;
import com.hierynomus.smbj.share.DiskShare;
import com.hierynomus.smbj.share.Share;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnDisabled; import org.apache.nifi.annotation.lifecycle.OnDisabled;
@ -116,16 +121,63 @@ public class SmbjClientProviderService extends AbstractControllerService impleme
@Override @Override
public SmbClientService getClient() throws IOException { public SmbClientService getClient() throws IOException {
final SmbjClientService client = new SmbjClientService(smbClient, authenticationContext, getServiceLocation()); Connection connection = smbClient.connect(hostname, port);
try { try {
client.connectToShare(hostname, port, shareName); return connectToShare(connection);
} catch (IOException e) { } catch (IOException e) {
client.forceFullyCloseConnection(); getLogger().debug("Closing stale connection and trying to create a new one for share " + getServiceLocation());
client.connectToShare(hostname, port, shareName);
closeConnection(connection);
connection = smbClient.connect(hostname, port);
return connectToShare(connection);
}
}
private SmbjClientService connectToShare(Connection connection) throws IOException {
final Session session;
final Share share;
try {
session = connection.authenticate(authenticationContext);
} catch (Exception e) {
throw new IOException("Could not create session for share " + getServiceLocation(), e);
} }
return client; try {
share = session.connectShare(shareName);
} catch (Exception e) {
closeSession(session);
throw new IOException("Could not connect to share " + getServiceLocation(), e);
}
if (!(share instanceof DiskShare)) {
closeSession(session);
throw new IllegalArgumentException("DiskShare not found. Share " + share.getClass().getSimpleName() + " found on " + getServiceLocation());
}
return new SmbjClientService(session, (DiskShare) share, getServiceLocation());
}
private void closeConnection(Connection connection) {
try {
if (connection != null) {
connection.close(true);
}
} catch (Exception e) {
getLogger().error("Could not close connection to {}", getServiceLocation(), e);
}
}
private void closeSession(Session session) {
try {
if (session != null) {
session.close();
}
} catch (Exception e) {
getLogger().error("Could not close session to {}", getServiceLocation(), e);
}
} }
@Override @Override

View File

@ -16,7 +16,6 @@
*/ */
package org.apache.nifi.services.smb; package org.apache.nifi.services.smb;
import static java.lang.String.format;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static java.util.stream.StreamSupport.stream; import static java.util.stream.StreamSupport.stream;
@ -27,14 +26,13 @@ import com.hierynomus.mssmb2.SMB2CreateDisposition;
import com.hierynomus.mssmb2.SMB2CreateOptions; import com.hierynomus.mssmb2.SMB2CreateOptions;
import com.hierynomus.mssmb2.SMB2ShareAccess; import com.hierynomus.mssmb2.SMB2ShareAccess;
import com.hierynomus.mssmb2.SMBApiException; import com.hierynomus.mssmb2.SMBApiException;
import com.hierynomus.smbj.SMBClient;
import com.hierynomus.smbj.auth.AuthenticationContext;
import com.hierynomus.smbj.connection.Connection;
import com.hierynomus.smbj.session.Session; import com.hierynomus.smbj.session.Session;
import com.hierynomus.smbj.share.Directory; import com.hierynomus.smbj.share.Directory;
import com.hierynomus.smbj.share.DiskShare; import com.hierynomus.smbj.share.DiskShare;
import com.hierynomus.smbj.share.File; import com.hierynomus.smbj.share.File;
import com.hierynomus.smbj.share.Share; import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.net.URI; import java.net.URI;
@ -42,66 +40,31 @@ import java.util.EnumSet;
import java.util.List; import java.util.List;
import java.util.stream.Stream; import java.util.stream.Stream;
public class SmbjClientService implements SmbClientService { class SmbjClientService implements SmbClientService {
private final static Logger LOGGER = LoggerFactory.getLogger(SmbjClientService.class);
private static final List<String> SPECIAL_DIRECTORIES = asList(".", ".."); private static final List<String> SPECIAL_DIRECTORIES = asList(".", "..");
private static final long UNCATEGORISED_ERROR = -1L; private static final long UNCATEGORISED_ERROR = -1L;
final private AuthenticationContext authenticationContext; private final Session session;
final private SMBClient smbClient; private final DiskShare share;
final private URI serviceLocation; private final URI serviceLocation;
private Connection connection; SmbjClientService(Session session, DiskShare share, URI serviceLocation) {
private Session session; this.session = session;
private DiskShare share; this.share = share;
public SmbjClientService(SMBClient smbClient, AuthenticationContext authenticationContext, URI serviceLocation) {
this.smbClient = smbClient;
this.authenticationContext = authenticationContext;
this.serviceLocation = serviceLocation; this.serviceLocation = serviceLocation;
} }
public void connectToShare(String hostname, int port, String shareName) throws IOException {
Share share;
try {
connection = smbClient.connect(hostname, port);
session = connection.authenticate(authenticationContext);
share = session.connectShare(shareName);
} catch (Exception e) {
close();
throw new IOException("Could not connect to share " + format("%s:%d/%s", hostname, port, shareName), e);
}
if (share instanceof DiskShare) {
this.share = (DiskShare) share;
} else {
close();
throw new IllegalArgumentException("DiskShare not found. Share " +
share.getClass().getSimpleName() + " found on " + format("%s:%d/%s", hostname, port,
shareName));
}
}
public void forceFullyCloseConnection() {
try {
if (connection != null) {
connection.close(true);
}
} catch (IOException ignore) {
} finally {
connection = null;
}
}
@Override @Override
public void close() { public void close() {
try { try {
if (session != null) { if (session != null) {
session.close(); session.close();
} }
} catch (IOException ignore) { } catch (Exception e) {
LOGGER.error("Could not close session to {}", serviceLocation, e);
} finally {
session = null;
} }
} }

View File

@ -20,9 +20,6 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.hierynomus.smbj.SMBClient;
import com.hierynomus.smbj.auth.AuthenticationContext;
import com.hierynomus.smbj.connection.Connection;
import com.hierynomus.smbj.session.Session; import com.hierynomus.smbj.session.Session;
import com.hierynomus.smbj.share.DiskShare; import com.hierynomus.smbj.share.DiskShare;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -33,20 +30,11 @@ import org.mockito.MockitoAnnotations;
class NiFiSmbjClientTest { class NiFiSmbjClientTest {
@Mock
DiskShare share;
@Mock
SMBClient smbClient;
@Mock
AuthenticationContext authenticationContext;
@Mock @Mock
Session session; Session session;
@Mock @Mock
Connection connection; DiskShare share;
@InjectMocks @InjectMocks
SmbjClientService underTest; SmbjClientService underTest;
@ -58,17 +46,12 @@ class NiFiSmbjClientTest {
@Test @Test
public void shouldCreateDirectoriesRecursively() throws Exception { public void shouldCreateDirectoriesRecursively() throws Exception {
when(smbClient.connect("hostname", 445))
.thenReturn(connection);
when(connection.authenticate(authenticationContext)).thenReturn(session);
when(session.connectShare(anyString())).thenReturn(share); when(session.connectShare(anyString())).thenReturn(share);
when(share.fileExists("directory")).thenReturn(true); when(share.fileExists("directory")).thenReturn(true);
when(share.fileExists("path")).thenReturn(false); when(share.fileExists("path")).thenReturn(false);
when(share.fileExists("to")).thenReturn(false); when(share.fileExists("to")).thenReturn(false);
when(share.fileExists("create")).thenReturn(false); when(share.fileExists("create")).thenReturn(false);
underTest.connectToShare("hostname", 445, "share");
underTest.createDirectory("directory/path/to/create"); underTest.createDirectory("directory/path/to/create");
verify(share).mkdir("directory/path"); verify(share).mkdir("directory/path");