001    /*
002     * Copyright (C) 2012 eXo Platform SAS.
003     *
004     * This is free software; you can redistribute it and/or modify it
005     * under the terms of the GNU Lesser General Public License as
006     * published by the Free Software Foundation; either version 2.1 of
007     * the License, or (at your option) any later version.
008     *
009     * This software is distributed in the hope that it will be useful,
010     * but WITHOUT ANY WARRANTY; without even the implied warranty of
011     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
012     * Lesser General Public License for more details.
013     *
014     * You should have received a copy of the GNU Lesser General Public
015     * License along with this software; if not, write to the Free
016     * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
017     * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
018     */
019    package org.crsh.ssh.term;
020    
021    import org.apache.sshd.SshServer;
022    import org.apache.sshd.common.Session;
023    import org.apache.sshd.server.PasswordAuthenticator;
024    import org.apache.sshd.server.PublickeyAuthenticator;
025    import org.apache.sshd.server.session.ServerSession;
026    import org.crsh.plugin.PluginContext;
027    import org.crsh.auth.AuthenticationPlugin;
028    import org.crsh.ssh.term.scp.SCPCommandFactory;
029    import org.crsh.term.TermLifeCycle;
030    import org.crsh.term.spi.TermIOHandler;
031    import org.crsh.vfs.Resource;
032    
033    import java.security.PublicKey;
034    import java.util.logging.Level;
035    import java.util.logging.Logger;
036    
037    /**
038     * Interesting stuff here : http://gerrit.googlecode.com/git-history/4b9e5e7fb9380cfadd28d7ffe3dc496dc06f5892/gerrit-sshd/src/main/java/com/google/gerrit/sshd/DatabasePubKeyAuth.java
039     */
040    public class SSHLifeCycle extends TermLifeCycle {
041    
042      /** . */
043      public static final Session.AttributeKey<String> USERNAME = new Session.AttributeKey<java.lang.String>();
044    
045      /** . */
046      public static final Session.AttributeKey<String> PASSWORD = new Session.AttributeKey<java.lang.String>();
047    
048      /** . */
049      private final Logger log = Logger.getLogger(SSHLifeCycle.class.getName());
050    
051      /** . */
052      private SshServer server;
053    
054      /** . */
055      private int port;
056    
057      /** . */
058      private Resource key;
059    
060      /** . */
061      private final AuthenticationPlugin authentication;
062    
063      /** . */
064      private Integer localPort;
065    
066      public SSHLifeCycle(PluginContext context, AuthenticationPlugin<?> authentication) {
067        super(context);
068    
069        //
070        this.authentication = authentication;
071      }
072    
073      public int getPort() {
074        return port;
075      }
076    
077      public void setPort(int port) {
078        this.port = port;
079      }
080    
081      /**
082       * Returns the local part after the ssh server has been succesfully bound or null. This is useful when
083       * the port is chosen at random by the system.
084       *
085       * @return the local port
086       */
087      public Integer getLocalPort() {
088              return localPort;
089      }
090      
091      public Resource getKey() {
092        return key;
093      }
094    
095      public void setKey(Resource key) {
096        this.key = key;
097      }
098    
099      @Override
100      protected void doInit() {
101        try {
102    
103          //
104          TermIOHandler handler = getHandler();
105    
106          //
107          SshServer server = SshServer.setUpDefaultServer();
108          server.setPort(port);
109          server.setShellFactory(new CRaSHCommandFactory(handler));
110          server.setCommandFactory(new SCPCommandFactory(getContext()));
111          server.setKeyPairProvider(new URLKeyPairProvider(key));
112    
113          //
114          if (authentication.getCredentialType().equals(String.class)) {
115            @SuppressWarnings("unchecked")
116            final AuthenticationPlugin<String> passwordAuthentication = (AuthenticationPlugin<String>)authentication;
117            server.setPasswordAuthenticator(new PasswordAuthenticator() {
118              public boolean authenticate(String _username, String _password, ServerSession session) {
119                boolean auth;
120                try {
121                  log.log(Level.FINE, "Using authentication plugin " + authentication + " to authenticate user " + _username);
122                  auth = passwordAuthentication.authenticate(_username, _password);
123                } catch (Exception e) {
124                  log.log(Level.SEVERE, "Exception authenticating user " + _username + " in authentication plugin: " + authentication, e);
125                  return false;
126                }
127    
128              // We store username and password in session for later reuse
129              session.setAttribute(USERNAME, _username);
130              session.setAttribute(PASSWORD, _password);
131    
132              //
133              return auth;
134            }
135          });
136          } else if (authentication.getCredentialType().equals(PublicKey.class)) {
137            @SuppressWarnings("unchecked")
138            final AuthenticationPlugin<PublicKey> keyAuthentication = (AuthenticationPlugin<PublicKey>)authentication;
139            server.setPublickeyAuthenticator(new PublickeyAuthenticator() {
140              public boolean authenticate(String username, PublicKey key, ServerSession session) {
141                try {
142                  log.log(Level.FINE, "Using authentication plugin " + authentication + " to authenticate user " + username);
143    
144    
145                  return keyAuthentication.authenticate(username, key);
146                }
147                catch (Exception e) {
148                  log.log(Level.SEVERE, "Exception authenticating user " + username + " in authentication plugin: " + authentication, e);
149                  return false;
150                }
151              }
152            });
153          }
154    
155          //
156          log.log(Level.INFO, "About to start CRaSSHD");
157          server.start();
158          localPort = server.getPort();
159          log.log(Level.INFO, "CRaSSHD started on port " + localPort);
160    
161          //
162          this.server = server;
163        }
164        catch (Throwable e) {
165          log.log(Level.SEVERE, "Could not start CRaSSHD", e);
166        }
167      }
168    
169      @Override
170      protected void doDestroy() {
171        if (server != null) {
172          try {
173            server.stop();
174          }
175          catch (InterruptedException e) {
176            log.log(Level.FINE, "Got an interruption when stopping server", e);
177          }
178        }
179      }
180    }