1 package com.maverick.ssl; 2 3 4 import java.io.EOFException ; 5 import java.io.FileNotFoundException ; 6 import java.io.IOException ; 7 import java.io.InputStream ; 8 import java.io.OutputStream ; 9 import java.nio.ByteBuffer ; 10 11 import javax.net.ssl.SSLContext; 12 import javax.net.ssl.SSLEngine; 13 import javax.net.ssl.SSLEngineResult; 14 import javax.net.ssl.SSLSession; 15 import javax.net.ssl.TrustManager; 16 17 import org.apache.commons.logging.Log; 18 import org.apache.commons.logging.LogFactory; 19 20 22 public class SSLTransportJCE implements SSLTransport { 23 24 private static SSLContext sslContext; 25 private SSLSession session; 26 private SSLEngine engine; 27 private SSLEngineResult.HandshakeStatus hsStatus; 28 private boolean initialHandshake; 29 30 SSLEngineResult res; 31 32 static Log log = LogFactory.getLog(SSLTransportJCE.class); 34 36 private InputStream rawIn; 37 private OutputStream rawOut; 38 39 byte[] buffer; 40 41 ByteBuffer in_outputData; 42 ByteBuffer in_inputData; 43 ByteBuffer out_outputData; 44 45 InputStream sslIn = new SSLInputStream(); 46 OutputStream sslOut = new SSLOutputStream(); 47 48 49 public InputStream getInputStream() { 50 return sslIn; 51 } 52 53 public OutputStream getOutputStream() { 54 return sslOut; 55 } 56 57 public void initialize(InputStream in, OutputStream out) throws IOException { 58 this.rawIn = in; 59 this.rawOut = out; 60 61 try { 63 engine = getSSLContext().createSSLEngine(); 65 66 engine.setUseClientMode(true); 68 69 session = engine.getSession(); 71 engine.beginHandshake(); 72 hsStatus = engine.getHandshakeStatus(); 73 initialHandshake = true; 74 75 buffer = new byte[session.getPacketBufferSize()]; 78 79 performInitialHandshake(); 80 81 in_outputData = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 82 in_inputData = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 83 out_outputData = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 84 in_outputData.flip(); } catch(Exception ex) { 86 throw new IOException (ex.getMessage()); 87 } 88 } 89 90 private void performInitialHandshake() throws IOException { 91 92 95 ByteBuffer dummy = ByteBuffer.allocate(0); 96 ByteBuffer outputBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 97 ByteBuffer inputBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 98 99 while(initialHandshake) { 100 101 switch (hsStatus) { 102 case FINISHED: 103 106 if(log.isDebugEnabled()) 108 log.debug("SSL Handshake finished"); 109 initialHandshake = false; 111 return; 112 113 case NEED_TASK: 114 115 118 if(log.isDebugEnabled()) 120 log.debug("Performing SSLEngine task"); 121 Runnable task; 124 while ((task = engine.getDelegatedTask()) != null) { 125 task.run(); 126 } 127 hsStatus = engine.getHandshakeStatus(); 128 break; 129 130 case NEED_UNWRAP: 131 132 136 if(log.isDebugEnabled()) 138 log.debug("Reading SSL data from raw InputStream"); 139 boolean forceRead = false; 141 142 do { 143 inputBuffer.flip(); 144 145 if(!inputBuffer.hasRemaining() || forceRead) { 146 147 inputBuffer.compact(); 148 int read = rawIn.read(buffer); 149 150 if(read==-1) 151 throw new EOFException ("Unexpected EOF whilst waiting for SSL unwrap"); 152 153 inputBuffer.put(buffer, 0, read); 154 inputBuffer.flip(); 155 } 156 157 res = engine.unwrap(inputBuffer, outputBuffer); 158 159 inputBuffer.compact(); 160 161 forceRead = res.getStatus()==SSLEngineResult.Status.BUFFER_UNDERFLOW; 162 163 } while(forceRead); 164 165 if(res.getStatus()!=SSLEngineResult.Status.OK) { 166 throw new IOException (res.getStatus().toString()); 167 } 168 169 hsStatus = res.getHandshakeStatus(); 170 171 sendOutput(outputBuffer); 172 break; 173 174 case NEED_WRAP: 175 176 180 if(log.isDebugEnabled()) 182 log.debug("Writing SSL data to raw OutputStream"); 183 res = engine.wrap(dummy, outputBuffer); 185 186 if(res.getStatus()!=SSLEngineResult.Status.OK) { 187 throw new IOException (res.getStatus().toString()); 188 } 189 190 hsStatus = res.getHandshakeStatus(); 191 192 sendOutput(outputBuffer); 193 194 break; 195 196 case NOT_HANDSHAKING: 197 200 log.error("doHandshake has caught a NOT_HANDSHAKING state.. This is impossible!"); 202 } 204 } 205 } 206 207 private void sendOutput(ByteBuffer outputBuffer) throws IOException { 208 outputBuffer.flip(); 209 210 int remaining = outputBuffer.remaining(); 211 while(remaining > 0) { 212 outputBuffer.get(buffer, 0, remaining); 213 rawOut.write(buffer, 0, remaining); 214 remaining = outputBuffer.remaining(); 215 } 216 217 outputBuffer.compact(); 218 } 219 220 226 private synchronized static SSLContext getSSLContext() throws IOException { 227 if (sslContext != null) { 228 return sslContext; 229 } 230 initializeSSL(); 231 return sslContext; 232 } 233 234 240 public static void initializeSSL() throws FileNotFoundException , IOException { 241 try { 242 sslContext = SSLContext.getInstance("TLS"); 243 sslContext.init(null, new TrustManager[] { new SSLTransportTrustManager() }, null); 244 } catch (Exception ex) { 245 throw new IOException ("SSL initialization failed: " + ex.getMessage()); 246 } 247 } 248 249 253 public synchronized static void setSSLContext(SSLContext context) { 254 sslContext = context; 255 } 256 257 class SSLInputStream extends InputStream { 258 259 SSLInputStream() { 260 261 } 262 263 public int read() throws IOException { 264 byte[] b = new byte[1]; 265 if(readData(b,0,1)==1) 266 return (int) b[0]; 267 else 268 return -1; 269 } 270 271 public int read(byte[] buf, int off, int len) throws IOException { 272 return readData(buf, off, len); 273 } 274 275 282 public int readData(byte[] buf, int off, int len) throws IOException { 283 284 try { 285 while(true) { 286 287 if(in_outputData.hasRemaining()) { 288 int actualLen = Math.min(len, in_outputData.remaining()); 289 in_outputData.get(buf, off, actualLen); 290 return actualLen; 291 } 292 293 int read = -1; 294 295 do { 296 297 in_inputData.flip(); 298 299 if(!in_inputData.hasRemaining() || res.getStatus()==SSLEngineResult.Status.BUFFER_UNDERFLOW) { 300 301 in_inputData.compact(); 302 read = rawIn.read(buffer, 0, Math.min(in_inputData.remaining(), buffer.length)); 303 304 if(read==-1) 305 return -1; 306 307 if(read==0) 308 return 0; 309 310 in_inputData.put(buffer,0,read); 311 in_inputData.flip(); 312 } 313 314 in_outputData.compact(); 315 res = engine.unwrap(in_inputData, in_outputData); 316 in_outputData.flip(); 317 in_inputData.compact(); 318 319 } while(res.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW); 320 321 } 322 } catch (IOException e) { 323 e.printStackTrace(); 324 throw e; 325 } 326 327 } 328 } 329 330 331 class SSLOutputStream extends OutputStream { 332 333 public void write(int b) throws IOException { 334 writeData(new byte[] { (byte)b }, 0, 1); 335 } 336 337 public void write(byte[] buf, int off, int len) throws IOException { 338 writeData(buf, off, len); 339 } 340 341 void writeData(byte[] buf, int off, int len) throws IOException { 342 343 ByteBuffer source = ByteBuffer.wrap(buf, off, len); 344 345 while(source.remaining() > 0) { 346 res = engine.wrap(source, out_outputData); 347 348 out_outputData.flip(); 349 350 int remaining = out_outputData.remaining(); 351 while(remaining > 0) { 352 out_outputData.get(buffer, 0, remaining); 353 rawOut.write(buffer, 0, remaining); 354 remaining = out_outputData.remaining(); 355 } 356 out_outputData.compact(); 357 } 358 } 359 } 360 361 public void close() throws SSLException { 362 363 try { 364 365 if(!engine.isOutboundDone()) 366 engine.closeOutbound(); 367 368 if(engine.isInboundDone()) 369 engine.closeInbound(); 370 371 } catch (Exception e) { 372 } finally { 373 try { 374 rawIn.close(); 375 } catch(Throwable t) { } 376 try { 377 rawOut.close(); 378 } catch(Throwable t) { } 379 } 380 381 } 382 } | Popular Tags |