module Text.EditDistance.STUArray (
levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
) where
import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
-- | Edit distance between two strings where each step is an insertion,
-- deletion or substitution priced by the given 'EditCosts'.
--
-- Both strings are measured with 'length' before delegating to
-- 'levenshteinDistanceWithLengths'.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 =
  let str1_len = length str1
      str2_len = length str2
   in levenshteinDistanceWithLengths costs str1_len str2_len str1 str2
-- | As 'levenshteinDistance', but the caller supplies the string lengths.
--
-- The lengths are expected to equal @length str1@ and @length str2@: they
-- size the arrays the strings are copied into. The mutable computation runs
-- inside 'runST', so the function itself is pure.
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 = runST action
  where
    action :: ST s Int
    action = levenshteinDistanceST costs str1_len str2_len str1 str2
-- | Levenshtein distance computed in 'ST' with two rolling rows of the cost
-- matrix.
--
-- Columns are indexed by the characters of @str1@ (1 .. @str1_len@) and rows
-- by the characters of @str2@ (1 .. @str2_len@). Index 0 stands for the empty
-- prefix. The result is the bottom-right cell.
--
-- The supplied lengths are expected to match the strings, since they size the
-- character arrays and the cost rows.
levenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
  str1_array <- stringToArray str1 str1_len
  str2_array <- stringToArray str2 str2_len

  -- Two rows of the cost matrix; the row worker alternates between them.
  cost_row <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
  cost_row' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

  -- Row 0. Turning the empty prefix into the empty prefix costs nothing, so
  -- cell 0 is written explicitly rather than trusting whatever 'newArray_'
  -- left there. Each later column adds the cost of deleting that character
  -- of str1.
  writeArray cost_row 0 0
  let fill_first_row deletion_cost (i, col_char) = do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        writeArray cost_row i deletion_cost'
        return deletion_cost'
  foldM_ fill_first_row 0 (zip [1 ..] str1)

  -- Rows 1 .. str2_len. The fold threads three values: the column-0 cost,
  -- the row just completed, and the scratch row.
  (_, final_row, _) <-
    foldM
      (levenshteinDistanceSTRowWorker costs str1_len str1_array str2_array)
      (0, cost_row, cost_row')
      [1 .. str2_len]

  readArray final_row str1_len
-- | Fill row @j@ (1-based; indexes @str2@) of the Levenshtein cost matrix.
--
-- The accumulator is @(insertion_cost, cost_row, cost_row')@:
--
--   * @insertion_cost@ is the column-0 cost of row @j - 1@.
--   * @cost_row@ is row @j - 1@.
--   * @cost_row'@ is a scratch row that receives row @j@.
--
-- The returned accumulator swaps the two rows, so the row just written
-- becomes the previous row for the next call.
levenshteinDistanceSTRowWorker
  :: EditCosts
  -> Int
  -> STUArray s Int Char
  -> STUArray s Int Char
  -> (Int, STUArray s Int Int, STUArray s Int Int)
  -> Int
  -> ST s (Int, STUArray s Int Int, STUArray s Int Int)
levenshteinDistanceSTRowWorker !costs !str1_len !str1_array !str2_array (!insertion_cost, !cost_row, !cost_row') !j = do
  row_char <- readArray str2_array j
  -- Column 0: reaching str2[1..j] from the empty prefix of str1 means
  -- inserting one more character than the previous row did.
  let insertion_cost' = insertion_cost + insertionCost costs row_char
  writeArray cost_row' 0 insertion_cost'
  -- Columns 1 .. str1_len (loopM_ assumed inclusive of both bounds).
  loopM_ 1 str1_len (colWorker row_char)
  return (insertion_cost', cost_row', cost_row)
  where
    -- Cell (i, j) from its left, upper-left and upper neighbours.
    colWorker row_char !i = do
      col_char <- readArray str1_array i
      left_up <- readArray cost_row (i - 1)
      left <- readArray cost_row' (i - 1)
      here_up <- readArray cost_row i
      let here = standardCosts costs row_char col_char left left_up here_up
      writeArray cost_row' i here
-- | Edit distance that allows insertions, deletions and substitutions, plus
-- transpositions of two adjacent characters. Each operation is priced by the
-- given 'EditCosts'.
--
-- Both strings are measured with 'length' before delegating to
-- 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance !costs str1 str2 =
  let str1_len = length str1
      str2_len = length str2
   in restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2
-- | As 'restrictedDamerauLevenshteinDistance', but the caller supplies the
-- string lengths.
--
-- The lengths are expected to equal @length str1@ and @length str2@. The
-- mutable computation runs inside 'runST', so the function itself is pure.
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 = runST action
  where
    action :: ST s Int
    action = restrictedDamerauLevenshteinDistanceST costs str1_len str2_len str1 str2
-- | Restricted Damerau-Levenshtein distance computed in 'ST'.
--
-- A transposition looks two rows back, so three rolling rows of the cost
-- matrix are kept. Row 0 (the empty prefix of str2) is built here. Row 1 has
-- no transposition candidate and is filled by 'firstRowColWorker'. Rows
-- 2 .. str2_len are delegated to the row worker.
--
-- The supplied lengths are expected to match the strings.
restrictedDamerauLevenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
  str1_array <- stringToArray str1 str1_len
  str2_array <- stringToArray str2 str2_len

  -- Row 0. Cell 0 is written explicitly rather than trusting 'newArray_'
  -- contents. Each later column adds the cost of deleting that str1 character.
  cost_row <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
  writeArray cost_row 0 0
  let fill_first_row deletion_cost (i, col_char) = do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        writeArray cost_row i deletion_cost'
        return deletion_cost'
  foldM_ fill_first_row 0 (zip [1 ..] str1)

  if str2_len == 0
    then readArray cost_row str1_len
    else do
      cost_row' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
      cost_row'' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

      -- Row 1: only the standard insert/delete/substitute moves apply.
      row_char <- readArray str2_array 1
      let zero = insertionCost costs row_char
      writeArray cost_row' 0 zero
      loopM_ 1 str1_len (firstRowColWorker str1_array row_char cost_row cost_row')

      -- Rows 2 .. str2_len. The accumulator holds, in order:
      --   the column-0 cost of the previous row,
      --   the row two back,
      --   the previous row,
      --   the scratch row,
      --   the previous row's str2 character.
      (_, _, final_row, _, _) <-
        foldM
          (restrictedDamerauLevenshteinDistanceSTRowWorker costs str1_len str1_array str2_array)
          (zero, cost_row, cost_row', cost_row'', row_char)
          [2 .. str2_len]
      readArray final_row str1_len
  where
    -- Cell (i, 1) from its left, upper-left and upper neighbours.
    firstRowColWorker !str1_array !row_char !cost_row !cost_row' !i = do
      col_char <- readArray str1_array i
      left_up <- readArray cost_row (i - 1)
      left <- readArray cost_row' (i - 1)
      here_up <- readArray cost_row i
      let here = standardCosts costs row_char col_char left left_up here_up
      writeArray cost_row' i here
-- | Fill row @j@ (@j >= 2@; indexes @str2@) of the restricted
-- Damerau-Levenshtein cost matrix.
--
-- The accumulator is
-- @(insertion_cost, cost_row, cost_row', cost_row'', prev_row_char)@:
--
--   * @insertion_cost@ is the column-0 cost of row @j - 1@.
--   * @cost_row@ is row @j - 2@.
--   * @cost_row'@ is row @j - 1@.
--   * @cost_row''@ is a scratch row that receives row @j@.
--   * @prev_row_char@ is @str2[j - 1]@.
--
-- The returned accumulator rotates the rows, so the next call sees
-- @(j - 1, j, scratch)@.
restrictedDamerauLevenshteinDistanceSTRowWorker
  :: EditCosts
  -> Int
  -> STUArray s Int Char
  -> STUArray s Int Char
  -> (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)
  -> Int
  -> ST s (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)
restrictedDamerauLevenshteinDistanceSTRowWorker !costs !str1_len !str1_array !str2_array (!insertion_cost, !cost_row, !cost_row', !cost_row'', !prev_row_char) !j = do
  row_char <- readArray str2_array j
  let insertion_cost' = insertion_cost + insertionCost costs row_char
  writeArray cost_row'' 0 insertion_cost'
  when (str1_len > 0) $ do
    -- Column 1 cannot take part in a transposition, so it uses only the
    -- standard moves.
    zero_up <- readArray cost_row' 0
    col_char <- readArray str1_array 1
    one_up <- readArray cost_row' 1
    let one = standardCosts costs row_char col_char insertion_cost' zero_up one_up
    writeArray cost_row'' 1 one
    -- Columns 2 .. str1_len (loopM_ assumed inclusive of both bounds).
    loopM_ 2 str1_len (colWorker row_char)
  return (insertion_cost', cost_row', cost_row'', cost_row, row_char)
  where
    -- Cell (i, j). Start from the standard moves. When the two characters
    -- ending each prefix are swapped copies of each other, also consider a
    -- transposition from cell (i - 2, j - 2).
    colWorker !row_char !i = do
      prev_col_char <- readArray str1_array (i - 1)
      col_char <- readArray str1_array i
      left_left_up_up <- readArray cost_row (i - 2)
      left_up <- readArray cost_row' (i - 1)
      left <- readArray cost_row'' (i - 1)
      here_up <- readArray cost_row' i
      let here_standard_only = standardCosts costs row_char col_char left left_up here_up
          here = if prev_row_char == col_char && prev_col_char == row_char
            then here_standard_only `min` (left_left_up_up + transpositionCost costs col_char row_char)
            else here_standard_only
      writeArray cost_row'' i here
-- | Cheapest cost of reaching a cell by one standard edit from a neighbour.
--
-- * From the left: delete @col_char@.
-- * From above: insert @row_char@.
-- * From the upper-left: substitute @col_char@ by @row_char@. This is free
--   when the two characters are equal.
standardCosts :: EditCosts -> Char -> Char -> Int -> Int -> Int -> Int
standardCosts !costs !row_char !col_char !cost_left !cost_left_up !cost_up =
  min (min viaDeletion viaInsertion) viaSubstitution
  where
    viaDeletion = cost_left + deletionCost costs col_char
    viaInsertion = cost_up + insertionCost costs row_char
    viaSubstitution
      | row_char == col_char = cost_left_up
      | otherwise = cost_left_up + substitutionCost costs col_char row_char
-- | Copy a string into a fresh mutable character array indexed from 1.
--
-- The array bounds are @(1, str_length)@. The caller is expected to pass the
-- string's actual length. A character that falls past the upper bound makes
-- the bounds-checked 'writeArray' fail.
stringToArray :: String -> Int -> ST s (STUArray s Int Char)
stringToArray str !str_length = do
  chars <- newArray_ (1, str_length)
  zipWithM_ (writeArray chars) [1 ..] str
  return chars