diff --git a/library/bignum.c b/library/bignum.c
index 1fea9b6d03..dd23195d36 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -2000,20 +2000,23 @@ int mpi_gen_prime( mpi *X, size_t nbits, int dh_flag,
 
         while( 1 )
         {
-            if( ( ret = mpi_is_prime( X, f_rng, p_rng ) ) == 0 )
+            /*
+             * First, check small factors for X and Y
+             * before doing Miller-Rabin on any of them
+             */
+            if( ( ret = mpi_check_small_factors(  X         ) ) == 0 &&
+                ( ret = mpi_check_small_factors( &Y         ) ) == 0 &&
+                ( ret = mpi_miller_rabin(  X, f_rng, p_rng  ) ) == 0 &&
+                ( ret = mpi_miller_rabin( &Y, f_rng, p_rng  ) ) == 0 )
             {
-                if( ( ret = mpi_is_prime( &Y, f_rng, p_rng ) ) == 0 )
-                    break;
-
-                if( ret != POLARSSL_ERR_MPI_NOT_ACCEPTABLE )
-                    goto cleanup;
+                break;
             }
 
             if( ret != POLARSSL_ERR_MPI_NOT_ACCEPTABLE )
                 goto cleanup;
 
             /*
-             * Next candidates. We want to preserve
+             * Next candidates. We want to preserve Y = (X-1) / 2 and
              * Y = 1 mod 2 and Y = 2 mod 3 (eq X = 3 mod 4 and X = 2 mod 3)
              * so up Y by 6 and X by 12.
              */