Traditional pipelines segment text first, then tag the results — propagating early segmentation errors into downstream POS tags. This module eliminates that cascade by searching over word boundaries and POS labels in a single Viterbi pass, finding the globally optimal combination.

Overview

from myspellchecker.algorithms.joint_segment_tagger import JointSegmentTagger

tagger = JointSegmentTagger(
    provider,
    pos_bigram_probs,
    pos_trigram_probs,
)

words, tags = tagger.segment_and_tag("မြန်မာနိုင်ငံ")
print(list(zip(words, tags)))
# [('မြန်မာ', 'N'), ('နိုင်ငံ', 'N')]

The Joint Optimization Problem

Traditional approach (sequential):
Text → Segment → Words → Tag → POS
       ↓          ↓
    (errors)  (propagate)
Joint approach (this module):
Text → Optimize(Segment + Tag) → Words + POS

        (global optimum)

Benefits

Aspect              Sequential  Joint
Optimization        Local       Global
Error propagation   Yes         Minimal
Passes              Multiple    Single
Ambiguity handling  Limited     Better

Mathematical Formulation

The tagger finds:
argmax P(words, tags | text)
= argmax Π P(word_i | word_{i-1}) × P(tag_i | tag_{i-2}, tag_{i-1}) × P(tag_i | word_i)
In log space:
= argmax Σ [log P(word_i | word_{i-1}) + log P(tag_i | tag_{i-2}, tag_{i-1}) + log P(tag_i | word_i)]
Components:
  • P(word_i | word_{i-1}) - Word n-gram probability (with unigram fallback)
  • P(tag_i | tag_{i-2}, tag_{i-1}) - Tag transition probability (HMM trigram, with bigram/unigram backoff)
  • P(tag_i | word_i) - Emission probability
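As a toy illustration of the log-space objective, the sketch below scores one candidate analysis by summing word, transition, and emission log probabilities. It uses unigram word scores for simplicity, and every probability table here is a hypothetical stand-in, not output from the real models:

```python
import math

# Hypothetical model probabilities for a two-word analysis.
word_probs = {"မြန်မာ": 1e-3, "နိုင်ငံ": 2e-3}                    # P(word_i)
tag_trans = {("<s>", "<s>", "N"): 0.4, ("<s>", "N", "N"): 0.5}   # P(tag | prev2, prev1)
emission = {("မြန်မာ", "N"): 0.9, ("နိုင်ငံ", "N"): 0.95}         # P(tag | word)

def joint_log_score(words, tags):
    """Sum word, tag-transition, and emission log probabilities."""
    score = 0.0
    prev2, prev1 = "<s>", "<s>"
    for w, t in zip(words, tags):
        score += math.log(word_probs[w])
        score += math.log(tag_trans[(prev2, prev1, t)])
        score += math.log(emission[(w, t)])
        prev2, prev1 = prev1, t
    return score

print(joint_log_score(["မြန်မာ", "နိုင်ငံ"], ["N", "N"]))
```

The Viterbi search maximizes exactly this quantity over all segmentations and tag sequences at once, rather than fixing the segmentation first.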

JointSegmentTagger Class

class JointSegmentTagger:
    """Joint word segmentation and POS tagging using unified Viterbi."""

    def __init__(
        self,
        provider: DictionaryProvider,
        pos_bigram_probs: Dict[Tuple[str, str], float],
        pos_trigram_probs: Dict[Tuple[str, str, str], float],
        pos_unigram_probs: Optional[Dict[str, float]] = None,
        word_tag_probs: Optional[Dict[str, Dict[str, float]]] = None,
        min_prob: float = 1e-10,
        max_word_length: int = 20,
        beam_width: int = 15,
        emission_weight: float = 1.2,
        word_score_weight: float = 1.0,
        use_morphology_fallback: bool = True,
    ):
        ...

Parameters

Parameter                Default   Description
provider                 Required  Dictionary provider for word lookups
pos_bigram_probs         Required  P(tag | prev_tag) transition probabilities
pos_trigram_probs        Required  P(tag | prev2, prev1) trigram probabilities
pos_unigram_probs        None      P(tag) priors for fallback
word_tag_probs           None      P(tag | word) emission probabilities
min_prob                 1e-10     Minimum probability for smoothing
max_word_length          20        Maximum word length in characters
beam_width               15        Beam size for pruning
emission_weight          1.2       Weight for emission scores
word_score_weight        1.0       Weight for word n-gram scores
use_morphology_fallback  True      Use morphological analysis for OOV words

Usage

Basic Segmentation and Tagging

# Create tagger
tagger = JointSegmentTagger(
    provider=provider,
    pos_bigram_probs=bigram_probs,
    pos_trigram_probs=trigram_probs,
)

# Process text
words, tags = tagger.segment_and_tag("မြန်မာနိုင်ငံသည်အရှေ့တောင်အာရှတွင်တည်ရှိသည်")

for word, tag in zip(words, tags):
    print(f"{word}\t{tag}")
# မြန်မာ     N
# နိုင်ငံ    N
# သည်       P_SENT
# အရှေ့တောင် N
# အာရှ      N
# တွင်      PPM
# တည်ရှိ    V
# သည်       P_SENT

Batch Processing

texts = [
    "မြန်မာနိုင်ငံ",
    "ကျေးဇူးတင်ပါသည်",
    "ဘာလဲ",
]

results = tagger.segment_and_tag_batch(texts)

for text, (words, tags) in zip(texts, results):
    print(f"Text: {text}")
    print(f"Words: {words}")
    print(f"Tags: {tags}")
    print()

State Space

The Viterbi algorithm operates on states:
@dataclass
class JointState:
    """State in the joint segmentation-tagging lattice."""

    word_start: int       # Character index where current word starts
    current_tag: str      # POS tag for current word
    prev_tag: str         # POS tag for previous word
    score: float          # Log probability score
    backpointer: Optional["JointState"]  # Previous state
State space: (position, word_start, current_tag, prev_tag)
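Because each state carries a backpointer to its predecessor, the best path found by Viterbi can be unwound to recover the word boundaries. A minimal sketch with hand-built states (the scores and spans are hypothetical):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class JointState:
    """State in the joint segmentation-tagging lattice."""
    word_start: int
    current_tag: str
    prev_tag: str
    score: float
    backpointer: Optional["JointState"] = None

# A toy two-word path: word 1 starts at char 0, word 2 at char 6.
s0 = JointState(word_start=0, current_tag="N", prev_tag="<s>", score=-2.0)
s1 = JointState(word_start=6, current_tag="N", prev_tag="N", score=-4.5, backpointer=s0)

def trace_boundaries(state):
    """Follow backpointers to recover word-start indices, left to right."""
    starts = []
    while state is not None:
        starts.append(state.word_start)
        state = state.backpointer
    return list(reversed(starts))

print(trace_boundaries(s1))  # [0, 6]
```

Slicing the input text at these start indices yields the segmented words, and reading `current_tag` along the same chain yields the POS sequence.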

Scoring Functions

Word Score

def _get_word_score(self, word: str, prev_word: str) -> float:
    """Get word n-gram score: log P(word | prev_word)."""
    # Try bigram first
    bigram_prob = self.provider.get_bigram_probability(prev_word, word)
    if bigram_prob > self.min_prob:
        return self.word_score_weight * math.log(bigram_prob)

    # Fallback to unigram
    freq = self.provider.get_word_frequency(word)
    if freq > 0:
        return self.word_score_weight * math.log(freq / 1e6)

    # Unknown word penalty
    return self.word_score_weight * (self.log_min_prob - len(word) * 0.5)

Tag Transition Score

def _get_tag_transition_score(self, tag: str, prev_tag: str, prev_prev_tag: str) -> float:
    """Get POS tag transition score: log P(tag | prev_prev_tag, prev_tag)."""
    # Try trigram
    trigram_prob = self.pos_trigram_probs.get((prev_prev_tag, prev_tag, tag), 0.0)
    if trigram_prob > self.min_prob:
        return math.log(trigram_prob)

    # Fallback to bigram
    bigram_prob = self.pos_bigram_probs.get((prev_tag, tag), 0.0)
    if bigram_prob > self.min_prob:
        return math.log(bigram_prob)

    # Fallback to unigram
    unigram_prob = self.pos_unigram_probs.get(tag, self.min_prob)
    return math.log(unigram_prob)
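The trigram → bigram → unigram backoff can be exercised in isolation. In this sketch the probability tables are hypothetical stand-ins for the trained models, and the function is a free-standing analogue of the method above:

```python
import math

min_prob = 1e-10
trigram = {("N", "PPM", "V"): 0.3}   # P(tag | prev2, prev1) where seen
bigram = {("PPM", "V"): 0.2}         # P(tag | prev) fallback
unigram = {"V": 0.1}                 # P(tag) last resort

def transition_logprob(tag, prev, prev2):
    """Trigram -> bigram -> unigram backoff for tag transitions."""
    p = trigram.get((prev2, prev, tag), 0.0)
    if p > min_prob:
        return math.log(p)
    p = bigram.get((prev, tag), 0.0)
    if p > min_prob:
        return math.log(p)
    return math.log(unigram.get(tag, min_prob))

# Trigram hit: the full (N, PPM, V) history is in the table.
print(transition_logprob("V", "PPM", "N"))
# Bigram backoff: (X, PPM, V) is unseen, so (PPM, V) is used.
print(transition_logprob("V", "PPM", "X"))
```

Each fallback level uses strictly less context, so rare tag histories still receive a finite (if low) log score instead of breaking the search.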

Emission Score

def _get_emission_score(self, word: str, tag: str) -> float:
    """Get emission score: log P(tag | word)."""
    if word in self.word_tag_probs:
        prob = self.word_tag_probs[word].get(tag, self.min_prob)
        return self.emission_weight * math.log(prob)

    # Fallback to tag prior
    if self.pos_unigram_probs:
        prob = self.pos_unigram_probs.get(tag, self.min_prob)
        return self.emission_weight * math.log(prob)

    return 0.0

Beam Pruning

To manage the large state space, beam pruning keeps only top-k states:
from heapq import nlargest

def _prune_beam(self, dp, end_pos):
    """Apply beam pruning to keep only the top-k states at end_pos."""
    if len(dp[end_pos]) > self.beam_width:
        top_states = nlargest(
            self.beam_width,
            dp[end_pos].items(),
            key=lambda x: x[1][0],  # sort by log-probability score
        )
        dp[end_pos] = dict(top_states)
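The effect of pruning on a single DP cell can be seen with a toy example. The state keys and scores below are hypothetical; the real cells map `(current_tag, prev_tag)`-style keys to `(score, backpointer)` pairs:

```python
from heapq import nlargest

# Toy DP cell: state key -> (log score, backpointer); values are hypothetical.
dp_cell = {
    ("N", "<s>"): (-1.2, None),
    ("V", "<s>"): (-3.4, None),
    ("PPM", "N"): (-0.8, None),
    ("ADJ", "N"): (-5.0, None),
}

beam_width = 2
top = dict(nlargest(beam_width, dp_cell.items(), key=lambda x: x[1][0]))
print(sorted(top))  # [('N', '<s>'), ('PPM', 'N')] - the two highest-scoring states
```

Only the surviving states are extended at the next character position, which caps the per-position work at `beam_width` regardless of how many candidate words end there.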

OOV Handling

For out-of-vocabulary words, the tagger uses morphological analysis:
def _get_valid_tags_for_word(self, word: str) -> Set[str]:
    """Get valid POS tags for a word."""
    tags = set()

    # From database
    pos_str = self.provider.get_word_pos(word)
    if pos_str:
        tags.update(pos_str.split("|"))

    # From word-tag probabilities
    if word in self.word_tag_probs:
        tags.update(self.word_tag_probs[word].keys())

    # Morphological fallback for OOV
    if not tags and self.morphology_analyzer:
        morpho_tags = self.morphology_analyzer.guess_pos(word)
        if morpho_tags:
            tags.update(morpho_tags)

    # Final fallback
    if not tags:
        return {self.UNKNOWN_TAG}

    return tags
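The lookup order (database, then emission table, then the unknown-tag fallback) can be sketched with dict-based stand-ins for the provider; the words, tag strings, and `UNKNOWN_TAG` value below are hypothetical, and the morphology step is omitted:

```python
# Stand-ins for the real dictionary provider and emission model.
known_pos = {"မြန်မာ": "N|ADJ"}            # database: "|"-separated POS string
word_tag_probs = {"နိုင်ငံ": {"N": 0.9}}    # emission table
UNKNOWN_TAG = "UNK"

def valid_tags(word):
    """Database -> emission table -> unknown-tag backoff."""
    tags = set()
    pos_str = known_pos.get(word)
    if pos_str:
        tags.update(pos_str.split("|"))
    if word in word_tag_probs:
        tags.update(word_tag_probs[word])
    return tags or {UNKNOWN_TAG}

print(valid_tags("မြန်မာ"))  # == {'N', 'ADJ'}
print(valid_tags("xyz"))     # == {'UNK'}
```

Restricting each word to its valid tags keeps the lattice small: an OOV word contributes only the fallback tag (or morphology-guessed tags) rather than every tag in the tagset.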

Performance

Complexity

  • Time: O(n × W × T²) where n=length, W=max_word_length, T=num_tags
  • Space: O(n × beam_width)

Benchmarks

Text Length  Sequential  Joint   Speedup
50 chars     5 ms        8 ms    0.6×
200 chars    20 ms       25 ms   0.8×
1000 chars   100 ms      90 ms   1.1×
The joint approach is slightly slower on short texts but comparable or faster on longer ones, with better accuracy in both cases.

Cache Management

# Clear cache when done
tagger.clear_cache()

Integration

With SpellChecker

# Sketch: a spell checker that composes the joint tagger
from typing import Optional

class SpellChecker:
    def __init__(self, joint_tagger: Optional[JointSegmentTagger] = None):
        self.tagger = joint_tagger

    def check(self, text: str):
        # Use joint segmentation and tagging
        words, tags = self.tagger.segment_and_tag(text)

        # Check spelling with POS context
        errors = []
        for i, (word, tag) in enumerate(zip(words, tags)):
            if not self._is_valid(word, tag):
                errors.append(self._create_error(word, tag, i))

        return errors

See Also