Bug 498938 - Add Levenshtein Edit Distance function to Sqlite so we can use it in queries. r=sdwilsh sr=vlad
author: Curtis Bartley <cbartley@mozilla.com>
Fri, 24 Jul 2009 11:39:34 -0400
changeset 30664 cdf35906c03b33b1782b77884392aa74053b3b3d
parent 30663 3621a1d223bdc04f55565b9452b3808975b3ee9b
child 30665 c023fd5706ff1bb086959ca2a65a1831dcfd6215
push id: 1
push user: root
push date: Tue, 26 Apr 2011 22:38:44 +0000
treeherder: mozilla-beta@bfdb6e623a36 [default view] [failures only]
perfherder: [talos] [build metrics] [platform microbench] (compared to previous push)
reviewers: sdwilsh, vlad
bugs: 498938
milestone: 1.9.2a1pre
Bug 498938 - Add Levenshtein Edit Distance function to Sqlite so we can use it in queries. r=sdwilsh sr=vlad
db/sqlite3/src/sqlite.def
storage/src/mozStorageSQLFunctions.cpp
storage/src/mozStorageSQLFunctions.h
storage/test/unit/test_levenshtein.js
--- a/db/sqlite3/src/sqlite.def
+++ b/db/sqlite3/src/sqlite.def
@@ -130,16 +130,17 @@ EXPORTS
         sqlite3_realloc
         sqlite3_release_memory
         sqlite3_reset
         sqlite3_reset_auto_extension
         sqlite3_result_blob
         sqlite3_result_double
         sqlite3_result_error
         sqlite3_result_error16
+        sqlite3_result_error_nomem
         sqlite3_result_int
         sqlite3_result_int64
         sqlite3_result_null
         sqlite3_result_text
         sqlite3_result_text16
         sqlite3_result_text16be
         sqlite3_result_text16le
         sqlite3_result_value
--- a/storage/src/mozStorageSQLFunctions.cpp
+++ b/storage/src/mozStorageSQLFunctions.cpp
@@ -60,17 +60,17 @@ namespace {
  * @param aPatternEnd
  *        An iterator at the end of the pattern to check for.
  * @param aStringItr
  *        An iterator at the start of the string to check for the pattern.
  * @param aStringEnd
  *        An iterator at the end of the string to check for the pattern.
  * @param aEscapeChar
  *        The character to use for escaping symbols in the pattern.
- * @returns 1 if the pattern is found, 0 otherwise.
+ * @return 1 if the pattern is found, 0 otherwise.
  */
 int
 likeCompare(nsAString::const_iterator aPatternItr,
             nsAString::const_iterator aPatternEnd,
             nsAString::const_iterator aStringItr,
             nsAString::const_iterator aStringEnd,
             PRUnichar aEscapeChar)
 {
@@ -144,40 +144,238 @@ likeCompare(nsAString::const_iterator aP
     }
 
     aPatternItr++;
   }
 
   return aStringItr == aStringEnd;
 }
 
+/**
+ * This class manages a dynamic array.  It can represent an array of any
+ * reasonable size, but if the array is "N" elements or smaller, it will be
+ * stored using fixed space inside the auto array itself.  If the auto array
+ * is a local variable, this internal storage will be allocated cheaply on the
+ * stack, similar to nsAutoString.  If a larger size is requested, the memory
+ * will be dynamically allocated from the heap.  Since the destructor will
+ * free any heap-allocated memory, client code doesn't need to care where the
+ * memory came from.
+ *
+ * @note AutoArray objects cannot be copied or assigned.  A copy would either
+ *       share (and later double-free) the heap buffer or point at the inline
+ *       storage of a different object, so both operations are disabled.
+ */
+template <class T, size_t N> class AutoArray
+{
+
+public:
+
+  /**
+   * Set up storage for the requested number of elements.
+   *
+   * @param size
+   *        The number of elements needed.  Requests of at most N elements
+   *        use the inline buffer; larger requests allocate with new[].
+   */
+  explicit AutoArray(size_t size)
+  : mBuffer(size <= N ? mAutoBuffer : new T[size])
+  {
+  }
+
+  ~AutoArray()
+  {
+    // Only memory that came from new[] may be released; the inline buffer
+    // is part of this object.
+    if (mBuffer != mAutoBuffer)
+      delete[] mBuffer;
+  }
+
+  /**
+   * Return the pointer to the allocated array.
+   * @note If the heap allocation failed and operator new[] reported that by
+   *       returning NULL, get() will return NULL!  Callers must check.
+   *
+   * @return the pointer to the allocated array
+   */
+  T *get()
+  {
+    return mBuffer;
+  }
+
+private:
+  // Copying is disallowed: declared private and intentionally left undefined.
+  AutoArray(const AutoArray &);
+  AutoArray &operator=(const AutoArray &);
+
+  T *mBuffer;           // Points to mAutoBuffer if we can use it, heap otherwise.
+  T mAutoBuffer[N];     // The internal memory buffer that we use if we can.
+};
+
+/**
+ * Calculate the Levenshtein Edit Distance separating two strings, i.e. the
+ * minimum number of single-character insertions, deletions and substitutions
+ * that turn one string into the other.
+ *
+ * @param aStringS
+ *        the first string to compare
+ * @param aStringT
+ *        the second string to compare
+ * @param _result
+ *        an outparam that receives the edit distance on success; it is left
+ *        at -1 when the function fails
+ * @return a Sqlite result code: SQLITE_OK on success, or SQLITE_NOMEM when a
+ *         row buffer is not available.
+ */
+int
+levenshteinDistance(const nsAString &aStringS,
+                    const nsAString &aStringT,
+                    int *_result)
+{
+  // Until we succeed, report a value that no real distance can take.
+  *_result = -1;
+
+  const PRUint32 lengthS = aStringS.Length();
+  const PRUint32 lengthT = aStringT.Length();
+
+  // When one string is empty, the distance is simply the length of the other.
+  if (lengthS == 0 || lengthT == 0) {
+    *_result = (lengthS == 0) ? lengthT : lengthS;
+    return SQLITE_OK;
+  }
+
+  // Conceptually the algorithm fills in a (lengthT + 1) x (lengthS + 1)
+  // matrix.  For s = "span" and t = "spam" it looks like this:
+  //
+  //               s   p   a   n      <-- characters of s (columns)
+  //           0   1   2   3   4
+  //       s   1   *   *   *   *
+  //       p   2   *   *   *   *
+  //       a   3   *   *   *   *
+  //       m   4   *   *   *   *
+  //       ^-- characters of t (rows)
+  //
+  // Row zero and column zero are known up front.  Every other cell depends
+  // only on its left, upper and upper-left neighbours, so walking the matrix
+  // row by row requires just two rows of storage at any time.
+
+  // Two row buffers of lengthS + 1 cells each.  AutoArray keeps short rows
+  // in its inline storage and releases any heap memory by itself.
+  AutoArray<int, nsAutoString::kDefaultStorageSize> firstBuffer(lengthS + 1);
+  AutoArray<int, nsAutoString::kDefaultStorageSize> secondBuffer(lengthS + 1);
+
+  // Raw pointers used for the actual cell accesses.
+  int *prevRow = firstBuffer.get();
+  NS_ENSURE_TRUE(prevRow, SQLITE_NOMEM);
+  int *currRow = secondBuffer.get();
+  NS_ENSURE_TRUE(currRow, SQLITE_NOMEM);
+
+  // Row zero of the matrix: 0, 1, 2, ..., lengthS.
+  for (PRUint32 col = 0; col <= lengthS; ++col) {
+    prevRow[col] = col;
+  }
+
+  const PRUnichar *sChars = aStringS.BeginReading();
+  const PRUnichar *tChars = aStringT.BeginReading();
+
+  // Walk the characters of "t"; character number "row" produces matrix row
+  // number row + 1.
+  for (PRUint32 row = 0; row < lengthT; ++row) {
+    // Column zero of this matrix row.
+    currRow[0] = row + 1;
+
+    const PRUnichar rowChar = tChars[row];
+
+    // Walk the characters of "s"; character number "col" produces the cell
+    // in column col + 1.
+    for (PRUint32 col = 0; col < lengthS; ++col) {
+      // Candidate reached from the upper-left neighbour: free if the two
+      // characters match, one edit otherwise.
+      int diagonal = prevRow[col] + ((sChars[col] == rowChar) ? 0 : 1);
+      // Candidate reached from the neighbour directly above.
+      int above = prevRow[col + 1] + 1;
+      // Candidate reached from the neighbour to the left.
+      int left = currRow[col] + 1;
+
+      // The cell takes the cheapest of the three candidates.
+      int smallest = diagonal;
+      if (above < smallest) {
+        smallest = above;
+      }
+      if (left < smallest) {
+        smallest = left;
+      }
+      currRow[col + 1] = smallest;
+    }
+
+    // Trade the two rows.  The row just finished becomes the reference row,
+    // and the stale one is reused for the next pass; every cell of it gets
+    // overwritten, so no clearing is needed.
+    int *recycled = prevRow;
+    prevRow = currRow;
+    currRow = recycled;
+  }
+
+  // Because of the trade above, the last computed row is now prevRow, and
+  // its final cell holds the distance.
+  *_result = prevRow[lengthS];
+  return SQLITE_OK;
+}
+
 } // anonymous namespace
 
 ////////////////////////////////////////////////////////////////////////////////
 //// Exposed Functions
 
 int
 registerFunctions(sqlite3 *aDB)
 {
   struct Functions {
     const char *zName;
     int nArg;
     int enc;
     void *pContext;
     void (*xFunc)(::sqlite3_context*, int, sqlite3_value**);
-  } functions[] = {
-    {"lower", 1, SQLITE_UTF16, 0,        caseFunction},
-    {"lower", 1, SQLITE_UTF8,  0,        caseFunction},
-    {"upper", 1, SQLITE_UTF16, (void*)1, caseFunction},
-    {"upper", 1, SQLITE_UTF8,  (void*)1, caseFunction},
+  };
+  
+  Functions functions[] = {
+    {"lower",               
+      1, 
+      SQLITE_UTF16, 
+      0,        
+      caseFunction},
+    {"lower",               
+      1, 
+      SQLITE_UTF8,  
+      0,        
+      caseFunction},
+    {"upper",               
+      1, 
+      SQLITE_UTF16, 
+      (void*)1, 
+      caseFunction},
+    {"upper",               
+      1, 
+      SQLITE_UTF8,  
+      (void*)1, 
+      caseFunction},
 
-    {"like",  2, SQLITE_UTF16, 0,        likeFunction},
-    {"like",  2, SQLITE_UTF8,  0,        likeFunction},
-    {"like",  3, SQLITE_UTF16, 0,        likeFunction},
-    {"like",  3, SQLITE_UTF8,  0,        likeFunction},
+    {"like",                
+      2, 
+      SQLITE_UTF16, 
+      0,        
+      likeFunction},
+    {"like",                
+      2, 
+      SQLITE_UTF8,  
+      0,        
+      likeFunction},
+    {"like",                
+      3, 
+      SQLITE_UTF16, 
+      0,        
+      likeFunction},
+    {"like",                
+      3, 
+      SQLITE_UTF8,  
+      0,        
+      likeFunction},
+
+    {"levenshteinDistance", 
+      2, 
+      SQLITE_UTF16, 
+      0,        
+      levenshteinDistanceFunction},
+    {"levenshteinDistance", 
+      2, 
+      SQLITE_UTF8,  
+      0,        
+      levenshteinDistanceFunction},
   };
 
   int rv = SQLITE_OK;
   for (size_t i = 0; SQLITE_OK == rv && i < NS_ARRAY_LENGTH(functions); ++i) {
     struct Functions *p = &functions[i];
     rv = ::sqlite3_create_function(aDB, p->zName, p->nArg, p->enc, p->pContext,
                                    p->xFunc, NULL, NULL);
   }
@@ -241,10 +439,45 @@ likeFunction(sqlite3_context *aCtx,
   A.EndReading(endString);
   nsAString::const_iterator itrPattern, endPattern;
   B.BeginReading(itrPattern);
   B.EndReading(endPattern);
   ::sqlite3_result_int(aCtx, likeCompare(itrPattern, endPattern, itrString,
                                          endString, E));
 }
 
+/**
+ * SQL function entry point for levenshteinDistance(s, t).  Sets the result
+ * to SQL NULL when either argument is SQL NULL, to the integer edit distance
+ * on success, and to an error otherwise.
+ *
+ * @param aCtx
+ *        The sqlite3_context that receives the result or error.
+ * @param aArgc
+ *        The number of arguments; expected to be 2.
+ * @param aArgv
+ *        The two string arguments to compare.
+ */
+void
+levenshteinDistanceFunction(sqlite3_context *aCtx,
+                            int aArgc,
+                            sqlite3_value **aArgv)
+{
+  NS_ASSERTION(2 == aArgc, "Invalid number of arguments!");
+
+  // If either argument is a SQL NULL, then return SQL NULL.
+  if (::sqlite3_value_type(aArgv[0]) == SQLITE_NULL ||
+      ::sqlite3_value_type(aArgv[1]) == SQLITE_NULL) {
+    ::sqlite3_result_null(aCtx);
+    return;
+  }
+
+  // Ask for the UTF-16 text first and its byte size second, which is the
+  // call order the SQLite documentation recommends for these accessors.
+  const PRUnichar *a =
+    static_cast<const PRUnichar *>(::sqlite3_value_text16(aArgv[0]));
+  int aLen = ::sqlite3_value_bytes16(aArgv[0]) / sizeof(PRUnichar);
+
+  const PRUnichar *b =
+    static_cast<const PRUnichar *>(::sqlite3_value_text16(aArgv[1]));
+  int bLen = ::sqlite3_value_bytes16(aArgv[1]) / sizeof(PRUnichar);
+
+  // SQL NULL was ruled out above, so per the SQLite documentation a NULL
+  // pointer here means the text conversion ran out of memory.  Never hand a
+  // null buffer to nsDependentString.
+  if (!a || !b) {
+    ::sqlite3_result_error_nomem(aCtx);
+    return;
+  }
+
+  // Compute the Levenshtein Distance, and return the result (or error).
+  int distance = -1;
+  const nsDependentString A(a, aLen);
+  const nsDependentString B(b, bLen);
+  int status = levenshteinDistance(A, B, &distance);
+  if (status == SQLITE_OK) {
+    ::sqlite3_result_int(aCtx, distance);
+  }
+  else if (status == SQLITE_NOMEM) {
+    ::sqlite3_result_error_nomem(aCtx);
+  }
+  else {
+    ::sqlite3_result_error(aCtx, "User function returned error code", -1);
+  }
+}
+
 } // namespace storage
 } // namespace mozilla
--- a/storage/src/mozStorageSQLFunctions.h
+++ b/storage/src/mozStorageSQLFunctions.h
@@ -83,12 +83,27 @@ NS_HIDDEN_(void) caseFunction(sqlite3_co
  *        The number of arguments the function is being called with.
  * @param aArgv
  *        An array of the arguments the functions is being called with.
  */
 NS_HIDDEN_(void) likeFunction(sqlite3_context *aCtx,
                               int aArgc,
                               sqlite3_value **aArgv);
 
+/**
+ * An implementation of the Levenshtein Edit Distance algorithm for use in
+ * Sqlite queries.  It is registered as the two-argument SQL function
+ * levenshteinDistance(s, t).  The SQL result is NULL if either argument is
+ * SQL NULL, otherwise the integer edit distance between the two strings; an
+ * error result is set if the computation fails.
+ *
+ * @param aCtx
+ *        The sqlite3_context that this function is being called on.
+ * @param aArgc
+ *        The number of arguments the function is being called with.
+ * @param aArgv
+ *        An array of the arguments the function is being called with.
+ */
+NS_HIDDEN_(void) levenshteinDistanceFunction(sqlite3_context *aCtx,
+                                             int aArgc,
+                                             sqlite3_value **aArgv);
+
 } // namespace storage
 } // namespace mozilla
 
 #endif // _mozStorageSQLFunctions_h_
new file mode 100644
--- /dev/null
+++ b/storage/test/unit/test_levenshtein.js
@@ -0,0 +1,108 @@
+/* ***** BEGIN LICENSE BLOCK *****
+ * Version: MPL 1.1/GPL 2.0/LGPL 2.1
+ *
+ * The contents of this file are subject to the Mozilla Public License Version
+ * 1.1 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * http://www.mozilla.org/MPL/
+ *
+ * Software distributed under the License is distributed on an "AS IS" basis,
+ * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
+ * for the specific language governing rights and limitations under the
+ * License.
+ *
+ * The Original Code is Storage Test Code.
+ *
+ * The Initial Developer of the Original Code is
+ * Mozilla Corporation.
+ * Portions created by the Initial Developer are Copyright (C) 2009
+ * the Initial Developer. All Rights Reserved.
+ *
+ * Contributor(s):
+ *   Curtis Bartley <cbartley@mozilla.com> (Original Author)
+ *
+ * Alternatively, the contents of this file may be used under the terms of
+ * either the GNU General Public License Version 2 or later (the "GPL"), or
+ * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
+ * in which case the provisions of the GPL or the LGPL are applicable instead
+ * of those above. If you wish to allow use of your version of this file only
+ * under the terms of either the GPL or the LGPL, and not to allow others to
+ * use your version of this file under the terms of the MPL, indicate your
+ * decision by deleting the provisions above and replace them with the notice
+ * and other provisions required by the GPL or the LGPL. If you do not delete
+ * the provisions above, a recipient may use your version of this file under
+ * the terms of any one of the MPL, the GPL or the LGPL.
+ *
+ * ***** END LICENSE BLOCK ***** */
+
+// This file tests the Levenshtein Distance function we've registered.
+
+function createUtf16Database()
+{
+  print("Creating the in-memory UTF-16-encoded database.");
+  let conn = getService().openSpecialDatabase("memory");
+  conn.executeSimpleSQL("PRAGMA encoding = 'UTF-16'");
+
+  print("Make sure the encoding was set correctly and is now UTF-16.");
+  let stmt = conn.createStatement("PRAGMA encoding");
+  do_check_true(stmt.executeStep());
+  let enc = stmt.getString(0);
+  stmt.finalize();
+
+  // The value returned will actually be UTF-16le or UTF-16be.
+  do_check_true(enc === "UTF-16le" || enc === "UTF-16be");
+
+  return conn;
+}
+
+function check_levenshtein(db, s, t, expectedDistance)
+{
+  var stmt = db.createStatement("SELECT levenshteinDistance(:s, :t) AS result");
+  stmt.params.s = s;
+  stmt.params.t = t;
+  try {
+    do_check_true(stmt.executeStep());
+    do_check_eq(expectedDistance, stmt.row.result);
+  } 
+  finally {
+    stmt.reset();
+    stmt.finalize();
+  }
+}
+
+function testLevenshtein(db)
+{
+  // Basic tests.
+  check_levenshtein(db, "", "", 0);
+  check_levenshtein(db, "foo", "", 3);
+  check_levenshtein(db, "", "bar", 3);
+  check_levenshtein(db, "yellow", "hello", 2);
+  check_levenshtein(db, "gumbo", "gambol", 2);
+  check_levenshtein(db, "kitten", "sitten", 1);
+  check_levenshtein(db, "sitten", "sittin", 1);
+  check_levenshtein(db, "sittin", "sitting", 1);
+  check_levenshtein(db, "kitten", "sitting", 3);
+  check_levenshtein(db, "Saturday", "Sunday", 3);
+  check_levenshtein(db, "YHCQPGK", "LAHYQQKPGKA", 6);
+
+  // Test SQL NULL handling.
+  check_levenshtein(db, "foo", null, null);
+  check_levenshtein(db, null, "bar", null);
+  check_levenshtein(db, null, null, null);
+  
+  // The levenshteinDistance function allocates temporary memory on the stack
+  // if it can.  Test some strings long enough to force a heap allocation.
+  var dots1000 = Array(1001).join(".");
+  var dashes1000 = Array(1001).join("-");
+  check_levenshtein(db, dots1000, dashes1000, 1000);
+}
+
+function run_test()
+{
+  testLevenshtein(getOpenedDatabase());
+  testLevenshtein(createUtf16Database());
+}
+
+
+
+