WvStreams
wvrsa.cc
00001 /*
00002  * Worldvisions Tunnel Vision Software:
00003  *   Copyright (C) 1997-2002 Net Integration Technologies, Inc.
00004  * 
00005  * RSA cryptography abstractions.
00006  */
00007 #include <assert.h>
00008 #include <openssl/rsa.h>
00009 #include <openssl/pem.h>
00010 #include "wvsslhacks.h"
00011 #include "wvrsa.h"
00012 #include "wvhex.h"
00013 #include "wvfileutils.h"
00014 
00015 /***** WvRSAKey *****/
00016 
00017 WvRSAKey::WvRSAKey()
00018     : debug("RSA", WvLog::Debug5)
00019 {
00020     rsa = NULL;
00021 }
00022 
00023 
00024 WvRSAKey::WvRSAKey(const WvRSAKey &k)
00025     : debug("RSA", WvLog::Debug5)
00026 {
00027     priv = k.priv;
00028 
00029     if (!priv)
00030         rsa = RSAPublicKey_dup(k.rsa);
00031     else
00032         rsa = RSAPrivateKey_dup(k.rsa);
00033 }
00034 
00035 
00036 WvRSAKey::WvRSAKey(struct rsa_st *_rsa, bool _priv)
00037     : debug("RSA", WvLog::Debug5)
00038 {        
00039     if (_rsa == NULL)
00040     {
00041         rsa = NULL;
00042         debug("Initializing with a NULL key.. are you insane?\n");
00043         return;
00044     }
00045 
00046     rsa = _rsa;
00047     priv = _priv;
00048 }
00049 
00050 
00051 WvRSAKey::WvRSAKey(WvStringParm keystr, bool _priv)
00052     : debug("RSA", WvLog::Debug5)
00053 {
00054     rsa = NULL;
00055 
00056     if (_priv)
00057         decode(RsaHex, keystr);
00058     else
00059         decode(RsaPubHex, keystr);
00060 
00061     priv = _priv;
00062 }
00063 
00064 
00065 WvRSAKey::WvRSAKey(int bits)
00066     : debug("RSA", WvLog::Debug5)
00067 {
00068     rsa = RSA_generate_key(bits, 0x10001, NULL, NULL);
00069     priv = true;
00070 }
00071 
00072 
00073 WvRSAKey::~WvRSAKey()
00074 {
00075     if (rsa)
00076         RSA_free(rsa);
00077 }
00078 
00079 
00080 bool WvRSAKey::isok() const
00081 {
00082     return rsa && (!priv || RSA_check_key(rsa) == 1);
00083 }
00084 
00085 
00086 WvString WvRSAKey::encode(const DumpMode mode) const
00087 {
00088     WvString nil;
00089     WvDynBuf retval;
00090     encode(mode, retval);
00091     return retval.getstr();
00092 }
00093 
00094 
00095 void WvRSAKey::encode(const DumpMode mode, WvBuf &buf) const
00096 {
00097     if (!rsa)
00098     {
00099         debug(WvLog::Warning, "Tried to encode RSA key, but RSA key is "
00100               "blank!\n");
00101         return;
00102     }
00103 
00104     if (mode == RsaHex || mode == RsaPubHex)
00105     {
00106         WvDynBuf keybuf;
00107 
00108         if (mode == RsaHex && priv)
00109         {
00110             size_t size = i2d_RSAPrivateKey(rsa, NULL);
00111             unsigned char *key = keybuf.alloc(size);
00112             size_t newsize = i2d_RSAPrivateKey(rsa, & key);
00113             assert(size == newsize);
00114         }
00115         else
00116         {
00117             size_t size = i2d_RSAPublicKey(rsa, NULL);
00118             unsigned char *key = keybuf.alloc(size);
00119             size_t newsize = i2d_RSAPublicKey(rsa, & key);
00120             assert(size == newsize);
00121         }
00122 
00123         buf.putstr(WvString(WvHexEncoder().strflushbuf(keybuf, true)));
00124     }
00125     else
00126     {
00127         BIO *bufbio = BIO_new(BIO_s_mem());
00128         BUF_MEM *bm;
00129         const EVP_CIPHER *enc = EVP_get_cipherbyname("rsa");
00130     
00131         if (mode == RsaPEM)
00132             PEM_write_bio_RSAPrivateKey(bufbio, rsa, enc,
00133                                         NULL, 0, NULL, NULL);
00134         else if (mode == RsaPubPEM)
00135             PEM_write_bio_RSAPublicKey(bufbio, rsa);
00136         else
00137             debug(WvLog::Warning, "Should never happen: tried to encode RSA "
00138                   "key with unsupported mode.");
00139 
00140         BIO_get_mem_ptr(bufbio, &bm);
00141         buf.put(bm->data, bm->length);
00142         BIO_free(bufbio);
00143     }
00144 }
00145 
00146 
00147 void WvRSAKey::decode(const DumpMode mode, WvStringParm encoded)
00148 {
00149     if (!encoded)
00150         return;
00151     
00152     WvDynBuf buf;
00153     buf.putstr(encoded);
00154     decode(mode, buf);
00155 }
00156 
00157 
00158 void WvRSAKey::decode(const DumpMode mode, WvBuf &encoded)
00159 {
00160     debug("Decoding RSA key.\n");
00161 
00162     if (rsa)
00163     {
00164         debug("Replacing already existent RSA key.\n");
00165         RSA_free(rsa);
00166         rsa = NULL;
00167     }
00168     priv = false;
00169 
00170     // we handle hexified keys a bit differently, since
00171     // OpenSSL has no built-in support for them...
00172     if (mode == RsaHex || mode == RsaPubHex)
00173     {
00174         // unhexify the supplied key
00175         WvDynBuf keybuf;
00176         if (!WvHexDecoder().flush(encoded, keybuf, true) || 
00177             keybuf.used() == 0)
00178         {
00179             debug("Couldn't unhexify RSA key.\n");
00180             return;
00181         }
00182     
00183         size_t keylen = keybuf.used();
00184         const unsigned char *key = keybuf.get(keylen);
00185     
00186         // create the RSA struct
00187         if (mode == RsaHex)
00188         {
00189             rsa = wv_d2i_RSAPrivateKey(NULL, &key, keylen);
00190             priv = true;
00191         }
00192         else
00193             rsa = wv_d2i_RSAPublicKey(NULL, &key, keylen);
00194 
00195         return;
00196     }
00197     else
00198     {
00199 
00200         BIO *membuf = BIO_new(BIO_s_mem());
00201         BIO_write(membuf, encoded.get(encoded.used()), encoded.used());
00202 
00203         if (mode == RsaPEM)
00204         {
00205             rsa = PEM_read_bio_RSAPrivateKey(membuf, NULL, NULL, NULL);
00206             priv = true;
00207         }
00208         else if (mode == RsaPubPEM)
00209             rsa = PEM_read_bio_RSAPublicKey(membuf, NULL, NULL, NULL);
00210         else 
00211             debug(WvLog::Warning, "Should never happen: tried to encode RSA "
00212                   "key with unsupported mode.");
00213 
00214         BIO_free_all(membuf);
00215     }
00216 }
00217 
00218 
00219 /***** WvRSAEncoder *****/
00220 
00221 WvRSAEncoder::WvRSAEncoder(Mode _mode, const WvRSAKey & _key) :
00222     mode(_mode), key(_key)
00223 {
00224     if (key.isok() && key.rsa != NULL)
00225         rsasize = RSA_size(key.rsa);
00226     else
00227         rsasize = 0; // BAD KEY! (should assert but would break compatibility)
00228 }
00229 
00230 
00231 WvRSAEncoder::~WvRSAEncoder()
00232 {
00233 }
00234 
00235 
00236 bool WvRSAEncoder::_reset()
00237 {
00238     return true;
00239 }
00240 
00241 
00242 bool WvRSAEncoder::_encode(WvBuf &in, WvBuf &out, bool flush)
00243 {
00244     if (rsasize == 0)
00245     {
00246         // IGNORE BAD KEY!
00247         in.zap();
00248         return false;
00249     }
00250         
00251     bool success = true;
00252     switch (mode)
00253     {
00254         case Encrypt:
00255         case SignEncrypt:
00256         {
00257             // reserve space for PKCS1_PADDING
00258             const size_t maxchunklen = rsasize - 12;
00259             size_t chunklen;
00260             while ((chunklen = in.used()) != 0)
00261             {
00262                 if (chunklen >= maxchunklen)
00263                     chunklen = maxchunklen;
00264                 else if (! flush)
00265                     break;
00266 
00267                 // encrypt a chunk
00268                 const unsigned char *data = in.get(chunklen);
00269                 unsigned char *crypt = out.alloc(rsasize);
00270                 size_t cryptlen = (mode == Encrypt) ?
00271                     RSA_public_encrypt(chunklen,
00272                     const_cast<unsigned char*>(data), crypt,
00273                     key.rsa, RSA_PKCS1_PADDING) :
00274                     RSA_private_encrypt(chunklen,
00275                     const_cast<unsigned char*>(data), crypt,
00276                     key.rsa, RSA_PKCS1_PADDING);
00277                 if (cryptlen != rsasize)
00278                 {
00279                     out.unalloc(rsasize);
00280                     success = false;
00281                 }
00282             }
00283             break;
00284         }
00285         case Decrypt:
00286         case SignDecrypt:
00287         {
00288             const size_t chunklen = rsasize;
00289             while (in.used() >= chunklen)
00290             {
00291                 // decrypt a chunk
00292                 const unsigned char *crypt = in.get(chunklen);
00293                 unsigned char *data = out.alloc(rsasize);
00294                 int cryptlen = (mode == Decrypt) ?
00295                     RSA_private_decrypt(chunklen,
00296                     const_cast<unsigned char*>(crypt), data,
00297                     key.rsa, RSA_PKCS1_PADDING) :
00298                     RSA_public_decrypt(chunklen,
00299                     const_cast<unsigned char*>(crypt), data,
00300                     key.rsa, RSA_PKCS1_PADDING);
00301                 if (cryptlen == -1)
00302                 {
00303                     out.unalloc(rsasize);
00304                     success = false;
00305                 }
00306                 else
00307                     out.unalloc(rsasize - cryptlen);
00308             }
00309             // flush does not make sense for us here
00310             if (flush && in.used() != 0)
00311                 success = false;
00312             break;
00313         }
00314     }
00315     return success;
00316 }
00317 
00318 
00319 /***** WvRSAStream *****/
00320 
00321 WvRSAStream::WvRSAStream(WvStream *_cloned,
00322     const WvRSAKey &_my_key, const WvRSAKey &_their_key,
00323     WvRSAEncoder::Mode readmode, WvRSAEncoder::Mode writemode) :
00324     WvEncoderStream(_cloned)
00325 {
00326     readchain.append(new WvRSAEncoder(readmode, _my_key), true);
00327     writechain.append(new WvRSAEncoder(writemode, _their_key), true);
00328     if (_my_key.isok() && _my_key.rsa)
00329         min_readsize = RSA_size(_my_key.rsa);
00330 }