LocalCat API
Translation
- class localcat.Translate.Translate(model_name_or_path='facebook/mbart-large-50-many-to-one-mmt', src_lang='zh_CN', tgt_lang='en_XX')
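Wraps a pre-trained Hugging Face sequence-to-sequence model (default: facebook/mbart-large-50-many-to-one-mmt) to translate text from src_lang to tgt_lang and to fine-tune the model on custom parallel data.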
- compute_metrics(eval_preds)
Computes evaluation metrics for machine translation model predictions.
- Parameters:
eval_preds (tuple) – Tuple containing predicted and label sequences.
- Returns:
Dictionary containing computed evaluation metrics.
- Return type:
dict
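Evaluation functions of this kind usually follow the Hugging Face translation recipe: unpack the (predictions, labels) tuple, replace the -100 values that mark ignored label positions with the pad token id, decode, strip whitespace, and score. The sketch below is hypothetical: a toy vocabulary stands in for the real tokenizer, and it reports exact match plus mean generated length. It makes no claim about which metric LocalCat actually computes (the Hugging Face tutorial uses sacreBLEU).

```python
PAD_ID = 0
# Hypothetical toy vocabulary standing in for tokenizer.batch_decode
VOCAB = {0: "<pad>", 1: "hello", 2: "world", 3: "good", 4: "morning"}


def decode(batch):
    """Decode token-id sequences, skipping padding (like skip_special_tokens=True)."""
    return [" ".join(VOCAB[t] for t in seq if t != PAD_ID) for seq in batch]


def compute_metrics_sketch(eval_preds):
    """Hypothetical metric function: exact match and mean generated length."""
    preds, labels = eval_preds
    # -100 marks label positions ignored by the loss; map them back to padding before decoding
    labels = [[t if t != -100 else PAD_ID for t in seq] for seq in labels]
    decoded_preds = [p.strip() for p in decode(preds)]
    decoded_labels = [l.strip() for l in decode(labels)]
    exact = sum(p == l for p, l in zip(decoded_preds, decoded_labels))
    # Generated length = number of non-padding tokens in each prediction
    gen_lens = [sum(t != PAD_ID for t in seq) for seq in preds]
    return {
        "exact_match": round(100 * exact / len(decoded_preds), 4),
        "gen_len": round(sum(gen_lens) / len(gen_lens), 4),
    }


preds = [[1, 2, 0], [3, 1, 0]]
labels = [[1, 2, -100], [3, 4, -100]]
print(compute_metrics_sketch((preds, labels)))  # {'exact_match': 50.0, 'gen_len': 2.0}
```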
- finetune(df, train_size=0.9, col_src='Chinese', col_tgt='English', max_length_input=512, max_length_target=512, prefix='', finetuned_model_path='model', batch_size=4, save_total_limit=3, evaluation_strategy='epoch', learning_rate=2e-05, weight_decay=0.01, num_train_epochs=1, compute_metrics=True)
Fine-tunes a pre-trained Seq2Seq model on a custom dataset.
- Parameters:
df (pandas.DataFrame) – DataFrame containing source and target language columns.
train_size (float, optional) – Proportion of data to use for training (default: 0.9).
col_src (str, optional) – Column name for source language text (default: ‘Chinese’).
col_tgt (str, optional) – Column name for target language text (default: ‘English’).
max_length_input (int, optional) – Maximum length of input sequences (default: 512).
max_length_target (int, optional) – Maximum length of target sequences (default: 512).
prefix (str, optional) – String to prepend to each source language sentence (default: ‘’).
finetuned_model_path (str, optional) – Path to save the fine-tuned model (default: “model”).
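batch_size (int, optional) – Batch size for training and evaluation (default: 4).
save_total_limit (int, optional) – Maximum number of saved checkpoints to keep (default: 3).
evaluation_strategy (str, optional) – When to run evaluation during training (default: ‘epoch’, i.e. at the end of every epoch).
learning_rate (float, optional) – Learning rate of the optimizer (default: 2e-05).
weight_decay (float, optional) – Weight decay applied during optimization (default: 0.01).
num_train_epochs (int, optional) – Number of training epochs (default: 1).
compute_metrics (bool, optional) – Whether to compute evaluation metrics with compute_metrics() during evaluation (default: True).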
- Returns:
Nothing; the fine-tuned model is saved to finetuned_model_path and the evaluation results are printed.
- Return type:
None
This method fine-tunes a pre-trained Seq2Seq model on a given dataset for machine translation. It performs the following steps:
Generates the training and evaluation datasets from the provided DataFrame.
Tokenizes the dataset using the specified parameters and prefix.
Defines training arguments for the fine-tuning process.
Creates a data collator for efficient batch processing.
Initializes a Seq2SeqTrainer object with the model, arguments, and datasets.
Trains the model on the training dataset.
Saves the fine-tuned model to the specified path.
Evaluates the model on the test dataset and prints the results.
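As a rough guide to training cost, the number of optimizer steps follows from the dataset size, train_size, batch_size and num_train_epochs. The helper below is a hypothetical back-of-the-envelope calculation, not part of LocalCat; it assumes a single device, no gradient accumulation, a training split rounded down, and that the last partial batch of each epoch is kept. How generate_dataset actually rounds the split is not documented.

```python
import math


def estimate_training_steps(n_rows, train_size=0.9, batch_size=4, num_train_epochs=1):
    """Hypothetical estimate of optimizer steps for finetune().

    Assumes one device, no gradient accumulation, and that the last
    partial batch of each epoch is kept (so steps per epoch round up).
    """
    n_train = int(n_rows * train_size)  # assumed: training split rounded down
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch * num_train_epochs


print(estimate_training_steps(1000))                      # 900 rows / 4 -> 225
print(estimate_training_steps(1000, num_train_epochs=3))  # 675
```

With the default evaluation_strategy=‘epoch’, evaluation would run once at the end of each of those epochs.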
- generate_dataset(df, train_size=0.9, col_src='Chinese', col_tgt='English')
Generates a DatasetDict for machine translation from a pandas DataFrame.
- Parameters:
df (pandas.DataFrame) – DataFrame containing the source and target language columns.
train_size (float, optional) – Proportion of the data to use for training (default: 0.9).
col_src (str, optional) – Name of the source language column (default: ‘Chinese’).
col_tgt (str, optional) – Name of the target language column (default: ‘English’).
- Returns:
A DatasetDict containing the training, validation, and test datasets.
- Return type:
datasets.DatasetDict
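The split can be pictured as a shuffle followed by a cut at train_size. The sketch below works on plain (source, target) pairs with the standard library, and `split_pairs` is a hypothetical name. Splitting the held-out rows evenly into validation and test is an assumption, since this page does not say how the remaining rows are divided, and the real method returns a datasets.DatasetDict rather than a dict of lists.

```python
import random


def split_pairs(pairs, train_size=0.9, seed=42):
    """Hypothetical stand-in for generate_dataset: shuffle, then cut.

    Assumption: rows not used for training are split evenly between
    validation and test.
    """
    rows = list(pairs)
    random.Random(seed).shuffle(rows)  # deterministic shuffle for reproducibility
    n_train = int(len(rows) * train_size)
    held_out = rows[n_train:]
    n_val = len(held_out) // 2
    return {
        "train": rows[:n_train],
        "validation": held_out[:n_val],
        "test": held_out[n_val:],
    }


pairs = [(f"中文句子 {i}", f"English sentence {i}") for i in range(20)]
splits = split_pairs(pairs)
print({name: len(rows) for name, rows in splits.items()})  # {'train': 18, 'validation': 1, 'test': 1}
```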
- postprocess_text(preds, labels)
Removes leading and trailing whitespace from predicted and label text sequences.
- Parameters:
preds (list) – List of predicted text sequences.
labels (list) – List of label text sequences.
- Returns:
A tuple containing the postprocessed predicted and label sequences.
- Return type:
tuple
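A minimal equivalent, assuming plain lists of strings, is shown below as a standalone sketch rather than LocalCat's own code. The Hugging Face translation tutorial additionally wraps each label in a list (`[[label.strip()] for label in labels]`) because sacreBLEU accepts several references per sentence; whether LocalCat does the same is not stated here.

```python
def postprocess_text(preds, labels):
    """Sketch: strip leading/trailing whitespace from predictions and labels."""
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    return preds, labels


cleaned_preds, cleaned_labels = postprocess_text(["  Hello world \n"], ["Hello world"])
print(cleaned_preds, cleaned_labels)  # ['Hello world'] ['Hello world']
```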
- tokenize_dataset(max_length_input=512, max_length_target=512, prefix='')
Preprocesses a dataset of text pairs for machine translation models.
- Parameters:
max_length_input (int) – Maximum length of the input sequence (default: 512).
max_length_target (int) – Maximum length of the target sequence (default: 512).
prefix (str, optional) – String to prepend to each source language sentence (default: ‘’). T5 models expect a task prefix before each input, such as ‘translate English to French: ’; for mBART and MarianMT, leave the prefix empty.
- Returns:
The original dataset transformed with tokenized input, target sequences, and labels.
- Return type:
datasets.Dataset
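The effect of the prefix and the two length limits can be shown with a toy whitespace tokenizer. This sketch is hypothetical: `toy_tokenize` and `preprocess_pair` are illustrative names, and a real implementation would call the model's Hugging Face tokenizer (for example with `truncation=True` and `max_length`), produce token ids rather than words, and map the function over the whole dataset.

```python
def toy_tokenize(text, max_length):
    """Hypothetical whitespace tokenizer that truncates to max_length tokens."""
    return text.split()[:max_length]


def preprocess_pair(src, tgt, max_length_input=512, max_length_target=512, prefix=""):
    """Sketch of per-example preprocessing: add the prefix, tokenize, truncate."""
    return {
        "input_tokens": toy_tokenize(prefix + src, max_length_input),
        "labels": toy_tokenize(tgt, max_length_target),
    }


# T5-style prefix; for mBART or MarianMT the prefix stays empty
example = preprocess_pair(
    "你好 世界", "hello world",
    max_length_input=8, max_length_target=1,  # the label is cut to 1 token
    prefix="translate Chinese to English: ",
)
print(example)
```

Note that the prefix counts toward max_length_input, so a long prefix leaves less room for the sentence itself.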
- translator(text, max_new_tokens=500)
Translates text from source language to target language.
- Parameters:
text (str) – Text to be translated.
max_new_tokens (int, optional) – Maximum number of tokens to generate, not counting the tokens in the prompt (default: 500).
- Returns:
Translated text.
- Return type:
str
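max_new_tokens caps only the tokens the model produces; the prompt's own length does not count against it. The toy greedy loop below illustrates that rule with a hypothetical next-token function. It does not show how LocalCat calls the model, which presumably goes through the Hugging Face `generate` method.

```python
from typing import Callable, List


def generate(prompt: List[str], next_token: Callable[[List[str]], str],
             max_new_tokens: int = 500, eos: str = "</s>") -> List[str]:
    """Toy greedy decoding loop: stop at EOS or after max_new_tokens new tokens."""
    output: List[str] = []
    for _ in range(max_new_tokens):
        # The "model" sees the prompt plus everything generated so far
        token = next_token(prompt + output)
        if token == eos:
            break
        output.append(token)
    return output  # only the newly generated tokens


def always_word(context: List[str]) -> str:
    """Hypothetical model that never emits EOS, so the cap is what stops it."""
    return "word"


long_prompt = ["tok"] * 1000
print(len(generate(long_prompt, always_word, max_new_tokens=5)))  # 5
```

Even though the prompt is 1000 tokens long, exactly five new tokens are produced.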
- translator_batch(df, col_src='Chinese', col_tgt='English')
Translates every entry of a DataFrame column from the source language to the target language and stores the results in another column.
- Parameters:
df (pd.DataFrame) – The Pandas DataFrame containing the text to translate.
col_src (str, optional) – The name of the column containing the source language text. Defaults to “Chinese”.
col_tgt (str, optional) – The name of the column to store the translated text. Defaults to “English”.
- Returns:
The original DataFrame with the translated text added to the specified target column.
- Return type:
pd.DataFrame
- Raises:
AttributeError – If the specified columns (col_src or col_tgt) are not found in the DataFrame.
- Prints:
The total time taken for the translation and the average speed per item.
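The documented behaviour (per-row translation into a target column, an AttributeError for a missing column, and a timing summary) can be sketched as below. A plain dict of column lists stands in for the DataFrame so the sketch runs without pandas, `translate_column` is a hypothetical name, and `translate_fn` is a hypothetical stand-in for the translator method. This sketch only checks the source column and creates the target column.

```python
import time
from typing import Callable, Dict, List


def translate_column(table: Dict[str, List[str]], translate_fn: Callable[[str], str],
                     col_src: str = "Chinese", col_tgt: str = "English") -> Dict[str, List[str]]:
    """Translate every value in table[col_src] and store the results in table[col_tgt]."""
    if col_src not in table:
        # Mirrors the documented AttributeError; col_tgt is created rather than required here
        raise AttributeError(f"Column {col_src!r} not found")
    start = time.perf_counter()
    table[col_tgt] = [translate_fn(text) for text in table[col_src]]
    elapsed = time.perf_counter() - start
    n = len(table[col_src])
    print(f"Translated {n} rows in {elapsed:.2f} s ({elapsed / max(n, 1):.4f} s per row)")
    return table


# Hypothetical dictionary-lookup "translator" used only for illustration
glossary = {"你好": "Hello", "谢谢": "Thank you"}
table = {"Chinese": ["你好", "谢谢"]}
translate_column(table, lambda text: glossary.get(text, text))
print(table["English"])  # ['Hello', 'Thank you']
```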
- class localcat.Translate.Local(model_name=None, model_path=None)
Helper class for deploying a Hugging Face translation model to AWS (Amazon S3 and Amazon SageMaker) and translating text with it.
- deploy(instance_type='ml.g4dn.4xlarge', transformers_version='4.37.0', pytorch_version='2.1.0', py_version='py310')
Deploys the HuggingFace model to an Amazon SageMaker endpoint.
- Parameters:
instance_type (str) – The type of Amazon SageMaker instance to use for deployment. Default is ‘ml.g4dn.4xlarge’.
transformers_version (str) – The version of the Transformers library to use. Default is ‘4.37.0’.
pytorch_version (str) – The version of PyTorch to use. Default is ‘2.1.0’.
py_version (str) – The version of Python to use. Default is ‘py310’.
- Returns:
None
- push_to_s3(bucket, prefix=None)
Pushes the model to an S3 bucket.
- Parameters:
bucket (str) – The name of the S3 bucket.
prefix (str, optional) – The prefix to be added to the S3 key.
- Returns:
None
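SageMaker's Hugging Face containers load a model from a `model.tar.gz` archive in S3, so an upload like this usually means packaging the model directory first and then building the object key from the optional prefix. The sketch below covers only those two local steps with the standard library. The archive name, the key layout, the helper names and the `boto3` call mentioned in the comment are assumptions about how push_to_s3 might work, not its documented behaviour.

```python
import tarfile
import tempfile
from pathlib import Path
from typing import Optional


def package_model(model_dir: str, archive_path: str) -> str:
    """Hypothetical helper: pack model_dir into a .tar.gz with files at the archive root."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for path in sorted(Path(model_dir).rglob("*")):
            if path.is_file():
                tar.add(path, arcname=str(path.relative_to(model_dir)))
    return archive_path


def s3_key(prefix: Optional[str] = None, filename: str = "model.tar.gz") -> str:
    """Hypothetical helper: place the file under the prefix when one is given."""
    return f"{prefix.strip('/')}/{filename}" if prefix else filename


# An upload would then be e.g. boto3.client("s3").upload_file(archive, bucket, key) (assumed, not run here)
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp, "model")
    model_dir.mkdir()
    (model_dir / "config.json").write_text("{}")
    archive = package_model(str(model_dir), str(Path(tmp, "model.tar.gz")))
    with tarfile.open(archive) as tar:
        print(tar.getnames())  # ['config.json']

print(s3_key("translation/v1"))  # translation/v1/model.tar.gz
print(s3_key())                  # model.tar.gz
```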
- translator(text)
Translates the given text using the HuggingFace model.
- Parameters:
text (str) – The text to be translated.
- Returns:
The translated text.
- Return type:
str
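A deployed Hugging Face endpoint is commonly called with a payload of the form {"inputs": text}, and a translation pipeline answers with a list such as [{"translation_text": ...}]. Assuming the endpoint behaves that way, the request and response handling reduces to the sketch below. `FakePredictor` and `translate_via_endpoint` are hypothetical names; the fake predictor stands in for the SageMaker predictor returned by deployment so the code runs offline.

```python
from typing import Any, Dict, List


class FakePredictor:
    """Hypothetical offline stand-in for a SageMaker Hugging Face predictor."""

    def predict(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        return [{"translation_text": f"<translated: {data['inputs']}>"}]


def translate_via_endpoint(predictor, text: str) -> str:
    """Send one text to the endpoint and unwrap the assumed translation-pipeline output."""
    response = predictor.predict({"inputs": text})
    return response[0]["translation_text"]


print(translate_via_endpoint(FakePredictor(), "你好"))  # <translated: 你好>
```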
- translator_batch(df, col_src='Chinese', col_tgt='English')
Translates a batch of text in a DataFrame column using the translator method.
- Parameters:
df (pandas.DataFrame) – The DataFrame containing the text to be translated.
col_src (str, optional) – The name of the source column containing the text to be translated. Defaults to ‘Chinese’.
col_tgt (str, optional) – The name of the target column to store the translated text. Defaults to ‘English’.
- Returns:
The DataFrame with the translated text in the target column.
- Return type:
pandas.DataFrame