LMOutput

LMOutput dataclass

Source code in flexeval/core/language_model/base.py
@dataclass
class LMOutput:
    text: str
    """
    The output text of the language model.
    """
    raw_text: str | None = None
    """
    The raw output text of the language model before post-processing.
    """
    finish_reason: str | None = None
    """
    The reason why the generation is finished.
    Typically,
    - 'stop': A stop sequence is generated.
    - 'length': The maximum length is reached.
    """

text instance-attribute

text: str

The output text of the language model.

raw_text class-attribute instance-attribute

raw_text: str | None = None

The raw output text of the language model before post-processing.

finish_reason class-attribute instance-attribute

finish_reason: str | None = None

The reason why the generation is finished. Typically:

  • 'stop': A stop sequence is generated.
  • 'length': The maximum length is reached.

__init__

__init__(
    text: str,
    raw_text: str | None = None,
    finish_reason: str | None = None,
) -> None
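
For illustration, a minimal sketch of constructing and inspecting an LMOutput; the field values here are made up:

from flexeval.core.language_model.base import LMOutput  # module path taken from the source location above

# A hypothetical result of a generation call
output = LMOutput(text="Paris", raw_text="Paris\n", finish_reason="stop")

if output.finish_reason == "length":
    # The generation hit the token limit, so `text` may be truncated.
    print("warning: output may be incomplete")
print(output.text)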

LanguageModel

LanguageModel is what you want to evaluate with this library.

It can generate text from input text, respond to chat messages, and compute log probabilities.

Parameters:

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.
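
As a hedged sketch (not part of the library), a toy subclass that overrides _batch_complete_text and therefore gets the public complete_text wrapper for free; the class and its behavior are invented for illustration:

from flexeval.core.language_model.base import LanguageModel, LMOutput

class EchoLM(LanguageModel):
    """A toy LanguageModel that simply echoes its input (illustrative only)."""

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        return [LMOutput(text=text, finish_reason="stop") for text in text_list]

lm = EchoLM()
print(lm.complete_text("hello"))        # a single LMOutput
print(lm.complete_text(["a", "b"]))     # a list of LMOutput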

Source code in flexeval/core/language_model/base.py
class LanguageModel:
    """LanguageModel is what you want to evaluate with this library.

    It can generate text from input text, respond to chat messages, and compute log probabilities.

    Args:
        string_processors: A single or a list of StringProcessor objects to process the model's output.

    """

    def __init__(self, string_processors: StringProcessor | list[StringProcessor] | None = None) -> None:
        if string_processors is None:
            string_processors = []
        elif isinstance(string_processors, StringProcessor):
            string_processors = [string_processors]

        self.string_processors = string_processors

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        """
        Generate text based on the input text list.

        Args:
            text_list: A list of input texts.
            stop_sequences: A string or a list of strings that will stop the generation when they are generated.
                This argument exists to give a common interface to various models that have different names for it.
            max_new_tokens: The maximum number of tokens to generate for each text.
                This argument exists to give a common interface to various models that have different names for it.
            **kwargs: Additional keyword arguments for text generation.
                The acceptable keys depend on the specific implementation of the model.
                These arguments override corresponding values in the model's default_gen_kwargs.
                Special cases:
                - 'stop_sequences' or any similar model-specific kwargs:
                    Merged with default_gen_kwargs instead of overriding.
        """
        msg = f"{self.__class__.__name__} cannot generate text."
        raise NotImplementedError(msg)

    def _batch_generate_chat_response(
        self,
        chat_messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[LMOutput]:
        """Generate chat responses based on the chat messages in the list.
        This method is used for chatbot models.

        Args:
            chat_messages_list: A list of chat messages.
        """
        msg = f"{self.__class__.__name__} cannot generate chat responses."
        raise NotImplementedError(msg)

    def _batch_compute_log_probs(
        self,
        text_list: list[str],
        prefix_list: list[str] | None = None,
        stride: int | None = None,
    ) -> list[float]:
        """
        Compute log probabilities of the text list.
        Used for computing the perplexity of text or solving multiple-choice questions.

        Args:
            text_list: A list of texts to compute log probabilities.
            prefix_list: A list of prefixes for each text.
            stride: The stride for computing log probabilities.
        """
        msg = f"{self.__class__.__name__} cannot compute perplexity."
        raise NotImplementedError(msg)

    def _batch_compute_chat_log_probs(
        self, prompt_list: list[list[dict[str, Any]]], response_list: list[dict[str, Any]]
    ) -> list[float]:
        """
        Compute log probabilities of the chat responses given the chat history.

        Args:
            prompt_list: A list of chat histories.
            response_list: A list of chat responses.
        """
        msg = f"{self.__class__.__name__} cannot compute chat log probabilities."
        raise NotImplementedError(msg)

    @final
    def complete_text(
        self,
        text: str | list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> LMOutput | list[LMOutput]:
        """
        A wrapper for `_batch_complete_text` that accepts a single text or a list of texts.
        This is a convenient method for end-users.
        To implement generation logic, you should override the `_batch_complete_text` method.
        """

        # Normalize the input text
        text_list = text
        if isinstance(text, str):
            text_list = [text]

        lm_outputs = self._batch_complete_text(
            text_list, stop_sequences=stop_sequences, max_new_tokens=max_new_tokens, **kwargs
        )

        # Post-process the generated text
        if self.string_processors:
            for lm_output in lm_outputs:
                lm_output.raw_text = lm_output.text
                for string_processor in self.string_processors:
                    lm_output.text = string_processor(lm_output.text)

        # Return the result
        if isinstance(text, str):
            return lm_outputs[0]
        return lm_outputs

    @final
    def generate_chat_response(
        self,
        chat_messages: list[dict[str, Any]] | list[list[dict[str, Any]]],
        **kwargs,
    ) -> LMOutput | list[LMOutput]:
        """
        A wrapper for `_batch_generate_chat_response` that accepts a single conversation (a list of chat messages) or a list of conversations.
        This is a convenient method for end-users.
        To implement generation logic, you should override the `_batch_generate_chat_response` method.
        """

        chat_messages_list = chat_messages
        if isinstance(chat_messages[0], dict):
            chat_messages_list = [chat_messages]

        lm_outputs = self._batch_generate_chat_response(chat_messages_list, **kwargs)

        # Post-process the generated text
        if self.string_processors:
            for lm_output in lm_outputs:
                lm_output.raw_text = lm_output.text
                for string_processor in self.string_processors:
                    lm_output.text = string_processor(lm_output.text)

        # Return the result
        if isinstance(chat_messages[0], dict):
            return lm_outputs[0]
        return lm_outputs

    @final
    def compute_log_probs(
        self,
        text_list: str | list[str],
        prefix_list: list[str] | None = None,
        stride: int | None = None,
    ) -> float | list[float]:
        """
        A wrapper for `_batch_compute_log_probs` that accepts a single text or a list of texts.
        This is a convenient method for end-users.
        To implement computation logic, you should override the `_batch_compute_log_probs` method.
        """

        if isinstance(text_list, str):
            return self._batch_compute_log_probs([text_list], prefix_list=prefix_list, stride=stride)[0]
        return self._batch_compute_log_probs(text_list, prefix_list=prefix_list, stride=stride)

    @final
    def compute_chat_log_probs(
        self, prompt: list[dict[str, Any]] | list[list[dict[str, Any]]], response: dict[str, Any] | list[dict[str, Any]]
    ) -> float | list[float]:
        """
        A wrapper for `_batch_compute_chat_log_probs` that accepts a single chat prompt or a list of chat prompts.
        This is a convenient method for end-users.
        To implement computation logic, you should override the `_batch_compute_chat_log_probs` method.
        """

        if isinstance(prompt[0], dict):
            return self._batch_compute_chat_log_probs([prompt], [response])[0]
        return self._batch_compute_chat_log_probs(prompt, response)

string_processors instance-attribute

string_processors = string_processors

__init__

__init__(
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
) -> None
Source code in flexeval/core/language_model/base.py
def __init__(self, string_processors: StringProcessor | list[StringProcessor] | None = None) -> None:
    if string_processors is None:
        string_processors = []
    elif isinstance(string_processors, StringProcessor):
        string_processors = [string_processors]

    self.string_processors = string_processors

complete_text

complete_text(
    text: str | list[str],
    stop_sequences: str | list[str] | None = None,
    max_new_tokens: int | None = None,
    **kwargs,
) -> LMOutput | list[LMOutput]

A wrapper for _batch_complete_text that accepts a single text or a list of texts. This is a convenient method for end-users. To implement generation logic, you should override the _batch_complete_text method.
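
For example, assuming lm is some concrete LanguageModel subclass (such as HuggingFaceLM below), the return type mirrors the input type:

# Single string in -> single LMOutput out
out = lm.complete_text("Q: What is 2 + 2?\nA:", stop_sequences=["\n"], max_new_tokens=8)
print(out.text, out.finish_reason)

# List of strings in -> list of LMOutput out
outs = lm.complete_text(["1 + 1 =", "2 + 2 ="], max_new_tokens=4)
print([o.text for o in outs])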

Source code in flexeval/core/language_model/base.py
@final
def complete_text(
    self,
    text: str | list[str],
    stop_sequences: str | list[str] | None = None,
    max_new_tokens: int | None = None,
    **kwargs,
) -> LMOutput | list[LMOutput]:
    """
    A wrapper for `_batch_complete_text` that accepts a single text or a list of texts.
    This is a convenient method for end-users.
    To implement generation logic, you should override the `_batch_complete_text` method.
    """

    # Normalize the input text
    text_list = text
    if isinstance(text, str):
        text_list = [text]

    lm_outputs = self._batch_complete_text(
        text_list, stop_sequences=stop_sequences, max_new_tokens=max_new_tokens, **kwargs
    )

    # Post-process the generated text
    if self.string_processors:
        for lm_output in lm_outputs:
            lm_output.raw_text = lm_output.text
            for string_processor in self.string_processors:
                lm_output.text = string_processor(lm_output.text)

    # Return the result
    if isinstance(text, str):
        return lm_outputs[0]
    return lm_outputs

generate_chat_response

generate_chat_response(
    chat_messages: list[dict[str, Any]]
    | list[list[dict[str, Any]]],
    **kwargs,
) -> LMOutput | list[LMOutput]

A wrapper for _batch_generate_chat_response that accepts a single conversation (a list of chat messages) or a list of conversations. This is a convenient method for end-users. To implement generation logic, you should override the _batch_generate_chat_response method.
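
For example, again assuming a concrete lm, a single conversation yields a single LMOutput, while a list of conversations yields a list:

messages = [{"role": "user", "content": "Hello!"}]

# One conversation -> one LMOutput
out = lm.generate_chat_response(messages, max_new_tokens=32)
print(out.text)

# A batch of conversations -> a list of LMOutput
outs = lm.generate_chat_response([messages, messages], max_new_tokens=32)
print(len(outs))  # 2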

Source code in flexeval/core/language_model/base.py
@final
def generate_chat_response(
    self,
    chat_messages: list[dict[str, Any]] | list[list[dict[str, Any]]],
    **kwargs,
) -> LMOutput | list[LMOutput]:
    """
    A wrapper for `_batch_generate_chat_response` that accepts a single conversation (a list of chat messages) or a list of conversations.
    This is a convenient method for end-users.
    To implement generation logic, you should override the `_batch_generate_chat_response` method.
    """

    chat_messages_list = chat_messages
    if isinstance(chat_messages[0], dict):
        chat_messages_list = [chat_messages]

    lm_outputs = self._batch_generate_chat_response(chat_messages_list, **kwargs)

    # Post-process the generated text
    if self.string_processors:
        for lm_output in lm_outputs:
            lm_output.raw_text = lm_output.text
            for string_processor in self.string_processors:
                lm_output.text = string_processor(lm_output.text)

    # Return the result
    if isinstance(chat_messages[0], dict):
        return lm_outputs[0]
    return lm_outputs

compute_log_probs

compute_log_probs(
    text_list: str | list[str],
    prefix_list: list[str] | None = None,
    stride: int | None = None,
) -> float | list[float]

A wrapper for _batch_compute_log_probs that accepts a single text or a list of texts. This is a convenient method for end-users. To implement computation logic, you should override the _batch_compute_log_probs method.
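
For example, conditional log probabilities P(text | prefix) can be compared to score multiple-choice continuations (assuming lm supports log-prob computation, e.g. HuggingFaceLM):

# Log probability of a single string
lp = lm.compute_log_probs("The capital of France is Paris.")

# Batched, conditioned on a shared prefix
prefix = "The capital of France is"
choices = [" Paris.", " London."]
lps = lm.compute_log_probs(choices, prefix_list=[prefix] * len(choices))
best = choices[max(range(len(choices)), key=lambda i: lps[i])]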

Source code in flexeval/core/language_model/base.py
@final
def compute_log_probs(
    self,
    text_list: str | list[str],
    prefix_list: list[str] | None = None,
    stride: int | None = None,
) -> float | list[float]:
    """
    A wrapper for `_batch_compute_log_probs` that accepts a single text or a list of texts.
    This is a convenient method for end-users.
    To implement computation logic, you should override the `_batch_compute_log_probs` method.
    """

    if isinstance(text_list, str):
        return self._batch_compute_log_probs([text_list], prefix_list=prefix_list, stride=stride)[0]
    return self._batch_compute_log_probs(text_list, prefix_list=prefix_list, stride=stride)

compute_chat_log_probs

compute_chat_log_probs(
    prompt: list[dict[str, Any]]
    | list[list[dict[str, Any]]],
    response: dict[str, Any] | list[dict[str, Any]],
) -> float | list[float]

A wrapper for _batch_compute_chat_log_probs that accepts a single chat prompt or a list of chat prompts. This is a convenient method for end-users. To implement computation logic, you should override the _batch_compute_chat_log_probs method.
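
For example, scoring a one-message assistant response against a chat prompt (a sketch; lm must implement _batch_compute_chat_log_probs):

prompt = [{"role": "user", "content": "Answer with exactly one letter: A or B."}]
response_a = {"role": "assistant", "content": "A"}
response_b = {"role": "assistant", "content": "B"}

# Single prompt/response -> single float
lp_a = lm.compute_chat_log_probs(prompt, response_a)

# Batched form: a list of prompts and a list of responses
lps = lm.compute_chat_log_probs([prompt, prompt], [response_a, response_b])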

Source code in flexeval/core/language_model/base.py
@final
def compute_chat_log_probs(
    self, prompt: list[dict[str, Any]] | list[list[dict[str, Any]]], response: dict[str, Any] | list[dict[str, Any]]
) -> float | list[float]:
    """
    A wrapper for `_batch_compute_chat_log_probs` that accepts a single chat prompt or a list of chat prompts.
    This is a convenient method for end-users.
    To implement computation logic, you should override the `_batch_compute_chat_log_probs` method.
    """

    if isinstance(prompt[0], dict):
        return self._batch_compute_chat_log_probs([prompt], [response])[0]
    return self._batch_compute_chat_log_probs(prompt, response)

HuggingFaceLM

LanguageModel implementation using Hugging Face Transformers.

Parameters:

  • model (str) –

    The model name or path of the Hugging Face model.

  • model_kwargs (dict[str, Any] | None, default: None ) –

    Keyword arguments for the model instantiation by from_pretrained().

  • tokenizer (str | None, default: None ) –

    The tokenizer name or path of the Hugging Face tokenizer.

  • tokenizer_kwargs (dict[str, Any] | None, default: None ) –

    Keyword arguments for the tokenizer instantiation by from_pretrained().

  • add_special_tokens (bool, default: False ) –

    Whether to add special tokens to the input. Note that whether BOS or EOS tokens are added depends on the tokenizer.

  • amp_dtype (Literal['float16', 'bfloat16'] | None, default: None ) –

    The dtype for automatic mixed precision.

  • random_seed (int, default: 42 ) –

    Random seed for the model.

  • load_peft (bool, default: False ) –

    Should be set to True when loading the model from PEFT weights.

  • custom_chat_template (str | None, default: None ) –

    A custom chat template for chatbot models. If specified, this overrides the default chat template of the tokenizer.

  • default_gen_kwargs (dict[str, Any] | None, default: None ) –

    Default generation kwargs to use when generating text.

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.

  • model_limit_tokens (int | None | Literal['default'], default: 'default' ) –

    An upper limit on the number of tokens (input + output) the model can handle. If max_new_tokens exceeds this limit in generate_chat_response(), it will be capped to this value. If this value is at or below the model's capacity and the input exceeds it, an empty string is returned instead of raising an error. If set to "default", the value is determined automatically when possible.
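
A minimal usage sketch, assuming a small causal LM checkpoint such as "gpt2" (any Hugging Face model name or local path works the same way):

from flexeval.core.language_model.hf_lm import HuggingFaceLM  # module path taken from the source location above

lm = HuggingFaceLM(
    model="gpt2",                          # example checkpoint; substitute your own
    default_gen_kwargs={"do_sample": False},
    model_limit_tokens="default",          # read the limit from the model config when possible
)
out = lm.complete_text("The quick brown fox", max_new_tokens=8)
print(out.text, out.finish_reason)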

Source code in flexeval/core/language_model/hf_lm.py
class HuggingFaceLM(LanguageModel):
    """
    LanguageModel implementation using Hugging Face Transformers.

    Args:
        model: The model name or path of the Hugging Face model.
        model_kwargs: Keyword arguments for the model instantiation by `from_pretrained()`.
        tokenizer: The tokenizer name or path of the Hugging Face tokenizer.
        tokenizer_kwargs: Keyword arguments for the tokenizer instantiation by `from_pretrained()`.
        add_special_tokens: Whether to add special tokens to the input.
            Note that whether BOS or EOS tokens are added depends on the tokenizer.
        amp_dtype: The dtype for automatic mixed precision.
        random_seed: Random seed for the model.
        load_peft: Should be set to True when loading the model from PEFT weights.
        custom_chat_template: A custom chat template for chatbot models.
            If specified, this overrides the default chat template of the tokenizer.
        default_gen_kwargs: Default generation kwargs to use when generating text.
        string_processors: A single or a list of StringProcessor objects to process the model's output.
        model_limit_tokens: An upper limit on the number of tokens (input + output) the model can handle.
            If `max_new_tokens` exceeds this limit in `generate_chat_response()`, it will be capped to this value.
            If this value is at or below the model's capacity and the input exceeds it,
            an empty string is returned instead of raising an error.
            If set to "default", the value is determined automatically when possible.
    """

    def __init__(
        self,
        model: str,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        add_special_tokens: bool = False,
        amp_dtype: Literal["float16", "bfloat16"] | None = None,
        random_seed: int = 42,
        load_peft: bool = False,
        custom_chat_template: str | None = None,
        default_gen_kwargs: dict[str, Any] | None = None,
        string_processors: StringProcessor | list[StringProcessor] | None = None,
        model_limit_tokens: int | None | Literal["default"] = "default",
    ) -> None:
        super().__init__(string_processors=string_processors)
        self._model_name_or_path = model
        tokenizer = tokenizer if tokenizer else model
        tokenizer_kwargs = tokenizer_kwargs or {}
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
        self.custom_chat_template = custom_chat_template
        self.add_special_tokens = add_special_tokens
        self.default_gen_kwargs = default_gen_kwargs or {}

        model_kwargs = get_default_model_kwargs(model_kwargs)
        if not load_peft:
            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
                model,
                **model_kwargs,
            )
        else:
            from peft import AutoPeftModelForCausalLM

            self.model = AutoPeftModelForCausalLM.from_pretrained(
                model,
                **model_kwargs,
            )

        self.model.eval()

        self.amp_dtype = amp_dtype
        if model_limit_tokens == "default":
            hf_config = self.model.config.to_dict()
            if "n_positions" in hf_config:
                model_limit_tokens = hf_config["n_positions"]
            elif "max_position_embeddings" in hf_config:
                model_limit_tokens = hf_config["max_position_embeddings"]
            else:
                msg = (
                    "`model_limit_tokens` was set to “default”, but the default max_position_embedeings "
                    "could not be found in the config. Set it to `None`."
                )
                logger.warning(msg)
        self.model_limit_tokens = model_limit_tokens

        transformers.set_seed(random_seed)

        logger.info(f"model device: {self.model.device}")
        logger.info(f"model dtype: {self.model.dtype}")
        logger.info(f"amp_dtype: {amp_dtype}")
        logger.info(f"random seed: {random_seed}")

    def _get_amp_context(self) -> contextlib.AbstractContextManager:
        if self.amp_dtype is None:
            return contextlib.nullcontext()
        if self.amp_dtype == "float16":
            return torch.amp.autocast(
                device_type=self.model.device.type,
                dtype=torch.float16,
            )
        if self.amp_dtype == "bfloat16":
            return torch.amp.autocast(
                device_type=self.model.device.type,
                dtype=torch.bfloat16,
            )

        msg = f"Invalid amp_dtype: {self.amp_dtype}"
        raise ValueError(msg)

    def _get_stop_token_ids(self, stop_sequences: list[str]) -> list[int]:
        stop_token_ids: list[int] = []
        for stop_seq in stop_sequences:
            # Try to convert the string to an id using `convert_tokens_to_ids`.
            # We do not use the `encode` method because, for sentencepiece-based tokenizers,
            # it adds a redundant space at the beginning of the string.
            stop_token_id = self.tokenizer.convert_tokens_to_ids(stop_seq)

            # NeoXTokenizer returns the unk token id when calling convert_tokens_to_ids
            # because each token is stored in a peculiar way
            # Ex. "」" -> "ãĢį"
            if stop_token_id == self.tokenizer.unk_token_id:
                # In such a case, we try to get the ID by calling the encode method.
                stop_seq_tokens = self.tokenizer.encode(stop_seq, add_special_tokens=False)
                if stop_seq_tokens:
                    stop_token_id = stop_seq_tokens[-1]
            # If the token does not match the specified string itself, we do not include it as a stop token id
            if self.tokenizer.decode(stop_token_id) != stop_seq:
                continue

            stop_token_ids.append(stop_token_id)
        return stop_token_ids

    @torch.inference_mode()
    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        ignore_eos: bool = False,
        **kwargs,
    ) -> list[LMOutput]:
        gen_kwargs = self.default_gen_kwargs.copy()
        gen_kwargs.update(kwargs)
        if max_new_tokens is not None:
            gen_kwargs["max_new_tokens"] = max_new_tokens

        model_inputs = tokenize_text_for_lm_prefix(
            text_list,
            self.tokenizer,
            add_special_tokens=self.add_special_tokens,
        ).to(self.model.device)
        input_token_length = model_inputs["input_ids"].shape[1]

        if self.model_limit_tokens:
            model_limit_new_tokens = self.model_limit_tokens - input_token_length
            if model_limit_new_tokens <= 0:
                msg = (
                    f"Received input that is longer than `model_limit_tokens = {self.model_limit_tokens}`. "
                    f"This batch returns empty strings."
                )
                logger.warning(msg)
                return [LMOutput(text="", finish_reason="input_length_limit") for _ in text_list]

            if "max_new_tokens" not in gen_kwargs or model_limit_new_tokens < gen_kwargs["max_new_tokens"]:
                gen_kwargs["max_new_tokens"] = model_limit_new_tokens

        # set the stop sequences
        stop_sequences = normalize_stop_sequences(
            stop_sequences_list=[
                stop_sequences,
                gen_kwargs.pop("stop_strings", None),  # This is used in the transformers `generate` function
                gen_kwargs.pop("stop_sequences", None),  # This is a common variable name used in flexeval
            ],
            bos_token=self.tokenizer.bos_token,
            eos_token=self.tokenizer.eos_token,
            ignore_eos=ignore_eos,
        )
        stop_token_ids = self._get_stop_token_ids(stop_sequences)
        gen_kwargs.update(
            {
                "eos_token_id": stop_token_ids,
                "pad_token_id": self.tokenizer.pad_token_id,
            },
        )

        with self._get_amp_context():
            generated_tokens = self.model.generate(**model_inputs, **gen_kwargs)

        # We strip the input text and stop sequences from the output text.
        lm_outputs: list[LMOutput] = []
        for generated_tensor in generated_tokens:
            input_tensor = generated_tensor[:input_token_length]
            output_tensor = generated_tensor[input_token_length:]

            input_tokens = [t for t in input_tensor.tolist() if t != self.tokenizer.pad_token_id]
            output_tokens = [t for t in output_tensor.tolist() if t != self.tokenizer.pad_token_id]
            decoded_text = decode_for_lm_continuation(output_tokens, input_tokens, self.tokenizer)

            finish_reason = "length"
            for stop_seq in stop_sequences:
                idx = decoded_text.find(stop_seq)
                if idx != -1:
                    decoded_text = decoded_text[:idx]
                    finish_reason = "stop"
            lm_outputs.append(LMOutput(text=decoded_text, finish_reason=finish_reason))
        return lm_outputs

    def _batch_generate_chat_response(
        self,
        chat_messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[LMOutput]:
        chat_messages_as_string = [
            self.tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
                add_generation_prompt=True,
                chat_template=self.custom_chat_template,
            )
            for chat_messages in chat_messages_list
        ]
        return self._batch_complete_text(chat_messages_as_string, **kwargs)

    @torch.inference_mode()
    def _batch_compute_log_probs(
        self,
        text_list: list[str],
        prefix_list: list[str] | None = None,
        stride: int | None = None,
    ) -> list[float]:
        batch_size = len(text_list)

        # prepare prefix encoding
        prefix_list = prefix_list if prefix_list else ["" for _ in range(batch_size)]
        # If the prefix is an empty string, replace it with the bos token, regardless of whether the model was trained with it.
        # This is needed to correctly calculate the log probabilities of the first token.
        for i in range(batch_size):
            if prefix_list[i] == "":
                prefix_list[i] = self.tokenizer.bos_token

        prefix_encoding = tokenize_text_for_lm_prefix(
            prefix_list,
            self.tokenizer,
            add_special_tokens=self.add_special_tokens,
        )

        # prepare continuation encoding
        # If the last token is a special token, it is treated as a beginning of a new sentence.
        continuation_encoding = tokenize_text_for_lm_continuation(
            text_list,
            self.tokenizer,
            as_continuation=[
                prefix_ids[-1] not in self.tokenizer.all_special_ids for prefix_ids in prefix_encoding.input_ids
            ],
        )

        input_data_dict: dict[str, torch.Tensor] = {}
        for key in continuation_encoding:
            input_data_dict[key] = torch.cat(
                [prefix_encoding[key].long(), continuation_encoding[key].long()],
                dim=1,
            )
        input_encoding = BatchEncoding(input_data_dict)

        max_length = self.model.config.max_position_embeddings
        stride = stride or max_length // 2
        if not (0 < stride < max_length):
            msg = f"stride must be in (0, {max_length}), but got {stride}"
            raise ValueError(msg)
        sequence_length = input_encoding.input_ids.size(1)

        with self._get_amp_context():
            # stores log probabilities of the next token for each input token
            last_computed_index: int = 0
            log_prob_of_next = torch.zeros_like(
                input_encoding.input_ids,
                dtype=torch.float32,
            )
            for chunk_start in range(0, sequence_length, stride):
                chunk_end = min(chunk_start + max_length, sequence_length)

                # Visualize the input / output processing
                # input_encoding.input_ids: [ 0  1  2  3  4 ]
                # chunk_input_ids:          [ 0  1  2  3    ]
                # chunk_target_ids:         [    1  2  3  4 ]

                input_start = chunk_start
                input_end = chunk_end - 1

                chunk_input_ids = input_encoding.input_ids[:, input_start:input_end].to(self.model.device)
                chunk_input_mask = input_encoding.attention_mask[:, input_start:input_end].to(self.model.device)
                chunk_target_ids = input_encoding.input_ids[:, chunk_start + 1 : chunk_end].to(self.model.device)

                chunkmodel_inputs = self.model.prepare_inputs_for_generation(
                    chunk_input_ids,
                    attention_mask=chunk_input_mask,
                )
                lm_outputs = self.model.forward(**chunkmodel_inputs)

                chunk_log_probs = F.log_softmax(lm_outputs.logits, dim=-1)
                # shape of chunk_log_probs: (batch_size, sequence_length, vocab_size)
                # shape of target_ids: (batch_size, sequence_length)
                # get the log probs of the target ids
                chunk_next_log_probs = chunk_log_probs.gather(
                    dim=-1,
                    index=chunk_target_ids.unsqueeze(-1),
                ).squeeze(-1)

                log_prob_of_next[:, last_computed_index:input_end] = chunk_next_log_probs[
                    :,
                    last_computed_index - input_start :,
                ]

                last_computed_index = input_end

                if chunk_end == sequence_length:
                    break

            log_prob_mask = input_encoding.attention_mask.clone()
            # replace the last token's log prob with 0
            for i in range(log_prob_mask.shape[0]):
                last_non_pad_index = log_prob_mask[i].nonzero(as_tuple=True)[0][-1].item()
                log_prob_mask[i, last_non_pad_index] = 0
            # mask out log probs of prefix tokens
            prefix_length = prefix_encoding.input_ids.shape[1]
            if prefix_length > 0:
                log_prob_mask[:, : prefix_length - 1] = 0
            total_log_probs = (log_prob_of_next * log_prob_mask).sum(dim=-1)
        return total_log_probs.tolist()

    def _batch_compute_chat_log_probs(
        self, prompt_list: list[list[dict[str, Any]]], response_list: list[dict[str, Any]]
    ) -> list[float]:
        prompt_as_string: list[str] = []
        response_as_string: list[str] = []
        for prompt, response in zip(prompt_list, response_list):
            prompt_as_string_i, response_as_string_i = get_prefix_and_completion_from_chat(
                prompt,
                response,
                self.tokenizer,
                custom_chat_template=self.custom_chat_template,
            )
            prompt_as_string.append(prompt_as_string_i)
            response_as_string.append(response_as_string_i)
        return self._batch_compute_log_probs(response_as_string, prefix_list=prompt_as_string)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self._model_name_or_path!r})"

tokenizer instance-attribute

tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
    tokenizer, **tokenizer_kwargs
)

custom_chat_template instance-attribute

custom_chat_template = custom_chat_template

add_special_tokens instance-attribute

add_special_tokens = add_special_tokens

default_gen_kwargs instance-attribute

default_gen_kwargs = default_gen_kwargs or {}

model instance-attribute

model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    model, **model_kwargs
)

amp_dtype instance-attribute

amp_dtype = amp_dtype

model_limit_tokens instance-attribute

model_limit_tokens = model_limit_tokens

__init__

__init__(
    model: str,
    model_kwargs: dict[str, Any] | None = None,
    tokenizer: str | None = None,
    tokenizer_kwargs: dict[str, Any] | None = None,
    add_special_tokens: bool = False,
    amp_dtype: Literal["float16", "bfloat16"] | None = None,
    random_seed: int = 42,
    load_peft: bool = False,
    custom_chat_template: str | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
    model_limit_tokens: int
    | None
    | Literal["default"] = "default",
) -> None
Source code in flexeval/core/language_model/hf_lm.py
def __init__(
    self,
    model: str,
    model_kwargs: dict[str, Any] | None = None,
    tokenizer: str | None = None,
    tokenizer_kwargs: dict[str, Any] | None = None,
    add_special_tokens: bool = False,
    amp_dtype: Literal["float16", "bfloat16"] | None = None,
    random_seed: int = 42,
    load_peft: bool = False,
    custom_chat_template: str | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    string_processors: StringProcessor | list[StringProcessor] | None = None,
    model_limit_tokens: int | None | Literal["default"] = "default",
) -> None:
    super().__init__(string_processors=string_processors)
    self._model_name_or_path = model
    tokenizer = tokenizer if tokenizer else model
    tokenizer_kwargs = tokenizer_kwargs or {}
    self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
    self.custom_chat_template = custom_chat_template
    self.add_special_tokens = add_special_tokens
    self.default_gen_kwargs = default_gen_kwargs or {}

    model_kwargs = get_default_model_kwargs(model_kwargs)
    if not load_peft:
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            model,
            **model_kwargs,
        )
    else:
        from peft import AutoPeftModelForCausalLM

        self.model = AutoPeftModelForCausalLM.from_pretrained(
            model,
            **model_kwargs,
        )

    self.model.eval()

    self.amp_dtype = amp_dtype
    if model_limit_tokens == "default":
        hf_config = self.model.config.to_dict()
        if "n_positions" in hf_config:
            model_limit_tokens = hf_config["n_positions"]
        elif "max_position_embeddings" in hf_config:
            model_limit_tokens = hf_config["max_position_embeddings"]
        else:
            msg = (
                "`model_limit_tokens` was set to “default”, but the default max_position_embedeings "
                "could not be found in the config. Set it to `None`."
            )
            logger.warning(msg)
    self.model_limit_tokens = model_limit_tokens

    transformers.set_seed(random_seed)

    logger.info(f"model device: {self.model.device}")
    logger.info(f"model dtype: {self.model.dtype}")
    logger.info(f"amp_dtype: {amp_dtype}")
    logger.info(f"random seed: {random_seed}")

__repr__

__repr__() -> str
Source code in flexeval/core/language_model/hf_lm.py
def __repr__(self) -> str:
    return f"{self.__class__.__name__}(model={self._model_name_or_path!r})"

LiteLLMChatAPI

LanguageModel implementation using LiteLLM. Various APIs are available, such as OpenAI, Claude, Gemini, etc. See also: https://docs.litellm.ai/docs/providers

Parameters:

  • model (str, default: 'openai/gpt-3.5-turbo' ) –

    The name of the model to use, e.g. 'openai/gpt-3.5-turbo'.

  • default_gen_kwargs (dict[str, Any] | None, default: None ) –

    Default generation kwargs to use when calling the API.

  • developer_message (str | None, default: None ) –

    Instructions to the model that are prioritized ahead of user messages. Previously called the system prompt.

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.

  • ignore_seed (bool, default: False ) –

    If True, ignore the seed specified in default_gen_kwargs. This is an option for models that do not support seed parameters such as anthropic/claude.

  • model_limit_completion_tokens (int | None, default: None ) –

    An upper limit on the number of tokens the model can generate. For example, if a too-large max_new_tokens is given to generate_chat_response(), this value will cap it.
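
A usage sketch, assuming provider credentials are supplied via the environment (for example OPENAI_API_KEY for an openai/... model):

from flexeval.core.language_model.litellm_api import LiteLLMChatAPI  # module path taken from the source location above

lm = LiteLLMChatAPI(
    model="openai/gpt-3.5-turbo",
    # `max_new_tokens` is translated to the provider-specific `max_tokens` internally
    default_gen_kwargs={"temperature": 0.0, "max_new_tokens": 128},
    ignore_seed=False,
)
out = lm.generate_chat_response([{"role": "user", "content": "Say hello in one word."}])
print(out.text, out.finish_reason)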

Source code in flexeval/core/language_model/litellm_api.py
class LiteLLMChatAPI(OpenAIChatAPI):
    """
    LanguageModel implementation using LiteLLM.
    Various APIs are available, such as OpenAI, Claude, Gemini, etc.
    See also: https://docs.litellm.ai/docs/providers

    Args:
        model: The name of the model to use, e.g. 'openai/gpt-3.5-turbo'.
        default_gen_kwargs: Default generation kwargs to use when calling the API.
        developer_message: Instructions to the model that are prioritized ahead of user messages.
            Previously called the system prompt.
        string_processors: A single or a list of StringProcessor objects to process the model's output.
        ignore_seed: If True, ignore the seed specified in default_gen_kwargs.
            This is an option for models that do not support seed parameters such as anthropic/claude.
        model_limit_completion_tokens: An upper limit on the number of tokens the model can generate.
            For example, if a too-large `max_new_tokens` is given to generate_chat_response(), this value will cap it.
    """

    def __init__(
        self,
        model: str = "openai/gpt-3.5-turbo",
        default_gen_kwargs: dict[str, Any] | None = None,
        developer_message: str | None = None,
        string_processors: StringProcessor | list[StringProcessor] | None = None,
        ignore_seed: bool = False,
        model_limit_completion_tokens: int | None = None,
    ) -> None:
        super().__init__(
            model=model,
            api_headers=None,
            default_gen_kwargs=default_gen_kwargs,
            developer_message=developer_message,
            string_processors=string_processors,
            model_limit_new_tokens=model_limit_completion_tokens,
        )
        self.model = model
        self.default_gen_kwargs = default_gen_kwargs or {}
        # convert the flexeval-specific argument name to the OpenAI-specific name
        if "max_new_tokens" in self.default_gen_kwargs:
            self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

        self.api_call_func = acompletion
        self.empty_response = convert_to_model_response_object(
            response_object=EMPTY_RESPONSE_OPENAI.to_dict(),
            model_response_object=ModelResponse(),
        )
        self.ignore_seed = ignore_seed

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        if "seed" in kwargs and self.ignore_seed:
            kwargs.pop("seed")
        return super()._batch_complete_text(text_list, stop_sequences, max_new_tokens, **kwargs)

    def _batch_generate_chat_response(
        self,
        chat_messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[LMOutput]:
        if "seed" in kwargs and self.ignore_seed:
            kwargs.pop("seed")
        return super()._batch_generate_chat_response(chat_messages_list, **kwargs)

    def _batch_compute_chat_log_probs(
        self,
        prompt_list: list[list[dict[str, Any]]],
        response_list: list[dict[str, Any]],
        temperature: float = 0,
        seed: int = 42,
        top_logprobs: int = 20,
    ) -> list[float | None]:
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self.model})"

model instance-attribute

model = model

default_gen_kwargs instance-attribute

default_gen_kwargs = default_gen_kwargs or {}

api_call_func instance-attribute

api_call_func = acompletion

empty_response instance-attribute

empty_response = convert_to_model_response_object(
    response_object=EMPTY_RESPONSE_OPENAI.to_dict(),
    model_response_object=ModelResponse(),
)

ignore_seed instance-attribute

ignore_seed = ignore_seed

__init__

__init__(
    model: str = "openai/gpt-3.5-turbo",
    default_gen_kwargs: dict[str, Any] | None = None,
    developer_message: str | None = None,
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
    ignore_seed: bool = False,
    model_limit_completion_tokens: int | None = None,
) -> None
Source code in flexeval/core/language_model/litellm_api.py
def __init__(
    self,
    model: str = "openai/gpt-3.5-turbo",
    default_gen_kwargs: dict[str, Any] | None = None,
    developer_message: str | None = None,
    string_processors: StringProcessor | list[StringProcessor] | None = None,
    ignore_seed: bool = False,
    model_limit_completion_tokens: int | None = None,
) -> None:
    super().__init__(
        model=model,
        api_headers=None,
        default_gen_kwargs=default_gen_kwargs,
        developer_message=developer_message,
        string_processors=string_processors,
        model_limit_new_tokens=model_limit_completion_tokens,
    )
    self.model = model
    self.default_gen_kwargs = default_gen_kwargs or {}
    # convert the flexeval-specific argument name to the OpenAI-specific name
    if "max_new_tokens" in self.default_gen_kwargs:
        self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

    self.api_call_func = acompletion
    self.empty_response = convert_to_model_response_object(
        response_object=EMPTY_RESPONSE_OPENAI.to_dict(),
        model_response_object=ModelResponse(),
    )
    self.ignore_seed = ignore_seed

__repr__

__repr__() -> str
Source code in flexeval/core/language_model/litellm_api.py
def __repr__(self) -> str:
    return f"{self.__class__.__name__}(model={self.model})"

OpenAIChatAPI

LanguageModel implementation using OpenAI's ChatGPT API. Note that this class is inherited by litellm_api.LiteLLMChatAPI, so be careful when making any modifications.

Parameters:

  • model (str, default: 'gpt-3.5-turbo' ) –

    The name of the model to use.

  • api_headers (dict[str, str] | None, default: None ) –

    A dictionary of headers to use when making requests to the OpenAI API.

  • default_gen_kwargs (dict[str, Any] | None, default: None ) –

    Default generation kwargs to use when calling the API.

  • developer_message (str | None, default: None ) –

    Instructions to the model that are prioritized ahead of user messages. Previously called the system prompt.

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.

  • model_limit_new_tokens (int | None, default: None ) –

    An upper limit on the number of tokens the model can generate. For example, if a too-large max_new_tokens is given to generate_chat_response(), this value will cap it.
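
A usage sketch, assuming OPENAI_API_KEY is set in the environment; the model name is only an example:

from flexeval.core.language_model.openai_api import OpenAIChatAPI  # module path taken from the source location above

lm = OpenAIChatAPI(
    model="gpt-4o-mini",                   # example model name
    developer_message="You are a concise assistant.",
    # `max_new_tokens` is translated to `max_completion_tokens` internally
    default_gen_kwargs={"temperature": 0.0, "max_new_tokens": 64},
)
out = lm.generate_chat_response([{"role": "user", "content": "What is 2 + 2?"}])
print(out.text, out.finish_reason)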

Source code in flexeval/core/language_model/openai_api.py
class OpenAIChatAPI(LanguageModel):
    """
    LanguageModel implementation using OpenAI's ChatGPT API.
    Note that this class is inherited by litellm_api.LiteLLMChatAPI, so be careful when making any modifications.

    Args:
        model: The name of the model to use.
        api_headers: A dictionary of headers to use when making requests to the OpenAI API.
        default_gen_kwargs: Default generation kwargs to use when calling the API.
        developer_message: Instructions to the model that are prioritized ahead of user messages.
            Previously called the system prompt.
        string_processors: A single or a list of StringProcessor objects to process the model's output.
        model_limit_new_tokens: An upper limit on the number of tokens the model can generate.
            For example, if a too-large `max_new_tokens` is given to generate_chat_response(), this value will cap it.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        api_headers: dict[str, str] | None = None,
        default_gen_kwargs: dict[str, Any] | None = None,
        developer_message: str | None = None,
        string_processors: StringProcessor | list[StringProcessor] | None = None,
        model_limit_new_tokens: int | None = None,
    ) -> None:
        super().__init__(string_processors=string_processors)
        self.model = model
        if api_headers is None:
            api_headers = {}
        client = AsyncOpenAI(**api_headers)
        self.api_call_func = client.chat.completions.create
        self.empty_response = EMPTY_RESPONSE
        self.default_gen_kwargs = default_gen_kwargs or {}
        # convert the flexeval-specific argument name to the OpenAI-specific name
        if "max_new_tokens" in self.default_gen_kwargs:
            self.default_gen_kwargs["max_completion_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

        self.developer_message = developer_message
        self.model_limit_new_tokens = model_limit_new_tokens

    async def _async_batch_run_chatgpt(
        self,
        messages_list: list[list[dict[str, Any]]],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[str]:
        """Send multiple chat requests to the OpenAI in parallel."""

        if self.developer_message is not None:
            # Insert the developer message at the beginning of each conversation
            messages_list = [
                [{"role": "developer", "content": self.developer_message}, *messages] for messages in messages_list
            ]

        gen_kwargs = self.default_gen_kwargs.copy()
        gen_kwargs.update(kwargs)
        if max_new_tokens is not None:
            if "max_completion_tokens" in gen_kwargs:
                msg = (
                    "You specified both `max_new_tokens` and `max_completion_tokens` in generation kwargs. "
                    "Note that `max_new_tokens` overrides `max_completion_tokens` by default. "
                    "It is recommended to specify only one of them to avoid unexpected behavior."
                )
                logger.warning(msg)
            gen_kwargs["max_completion_tokens"] = max_new_tokens

        if self.model_limit_new_tokens and (gen_kwargs.get("max_completion_tokens", 0) > self.model_limit_new_tokens):
            msg = (
                f"The specified `max_new_tokens` ({gen_kwargs['max_completion_tokens']}) exceeds"
                f"the model’s capability ({self.model_limit_new_tokens} tokens). It will be reduced."
            )
            logger.warning(msg)
            gen_kwargs["max_completion_tokens"] = self.model_limit_new_tokens

        stop_sequences = normalize_stop_sequences(
            stop_sequences_list=[
                stop_sequences,
                gen_kwargs.pop("stop", None),  # This is used in the OpenAI API
                gen_kwargs.pop("stop_sequences", None),  # This is a common variable name used in flexeval
            ],
        )

        tasks = [
            _retry_on_error(
                # Define an anonymous function with a lambda expression and pass it,
                # and call it inside the _retry_on_error function
                openai_call=lambda x=ms: self.api_call_func(
                    model=self.model,
                    messages=x,
                    stop=stop_sequences,
                    **gen_kwargs,
                ),
                empty_response=self.empty_response,
            )
            for ms in messages_list
        ]
        return await asyncio.gather(*tasks)

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        messages_list = [[{"role": "user", "content": text}] for text in text_list]
        api_responses = asyncio.run(
            self._async_batch_run_chatgpt(
                messages_list,
                stop_sequences=stop_sequences,
                max_new_tokens=max_new_tokens,
                **kwargs,
            ),
        )
        outputs = [
            LMOutput(text=res.choices[0].message.content, finish_reason=res.choices[0].finish_reason)
            for res in api_responses
        ]

        if all(output.text == "" for output in outputs):
            logger.warning("All generated texts are empty strings. Something may be wrong.")
        return outputs

    def _batch_generate_chat_response(
        self,
        chat_messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[LMOutput]:
        api_responses = asyncio.run(
            self._async_batch_run_chatgpt(chat_messages_list, **kwargs),
        )
        outputs = [
            LMOutput(text=res.choices[0].message.content, finish_reason=res.choices[0].finish_reason)
            for res in api_responses
        ]
        if all(output.text == "" for output in outputs):
            logger.warning("All generated texts are empty strings. Something may go wrong.")
        return outputs

    def _batch_compute_chat_log_probs(
        self,
        prompt_list: list[list[dict[str, Any]]],
        response_list: list[dict[str, Any]],
        temperature: float = 0,
        seed: int = 42,
        top_logprobs: int = 20,
    ) -> list[float | None]:
        """
        Return the log probability of a one-token response only, due to a restriction of the OpenAI API.
        If a response with two or more tokens is passed, an error is raised.

        This function is mainly used for computing a weighted average over multiple-choice prompts.
        By design, the same prompt is passed once per choice, but only one API request per unique prompt
        is needed because the OpenAI API returns a list of log probs. So, this function removes duplicate
        prompts before calling the API and maps the results back onto the original prompt list.
        """

        # Check the number of tokens is 1
        response_contents = [resp["content"] for resp in response_list]
        for response_content in response_contents:
            num_tokens = number_of_tokens_in_openai_model(self.model, response_content)
            if num_tokens > 1:
                err_msg = f"OpenAIChatAPI.batch_compute_chat_log_probs is not applicable for two or more tokens of response content: '{response_content}'"  # noqa: E501
                raise NotImplementedError(err_msg)

        # For saving cost, remove duplication from message_list for an API request.
        unique_prompt_list = remove_duplicates_from_prompt_list(prompt_list)
        api_responses = asyncio.run(
            self._async_batch_run_chatgpt(
                unique_prompt_list,
                max_completion_tokens=1,
                seed=seed,
                logprobs=True,
                top_logprobs=top_logprobs,
            ),
        )

        log_probs = []
        top_logprobs_list = [res.choices[0].logprobs.content[0].top_logprobs for res in api_responses]
        for index, prompt in enumerate(prompt_list):
            target_token = response_contents[index]
            index_in_unique = unique_prompt_list.index(prompt)

            log_prob = None  # if target token not in top_logprobs, return None for log_prob of the token
            top_logprobs = top_logprobs_list[index_in_unique]
            for token_logprob in top_logprobs:
                if token_logprob.token == target_token:
                    log_prob = token_logprob.logprob
                    break
            log_probs.append(log_prob)

        return log_probs

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self.model})"

model instance-attribute

model = model

api_call_func instance-attribute

api_call_func = client.chat.completions.create

empty_response instance-attribute

empty_response = EMPTY_RESPONSE

default_gen_kwargs instance-attribute

default_gen_kwargs = default_gen_kwargs or {}

developer_message instance-attribute

developer_message = developer_message

model_limit_new_tokens instance-attribute

model_limit_new_tokens = model_limit_new_tokens

__init__

__init__(
    model: str = "gpt-3.5-turbo",
    api_headers: dict[str, str] | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    developer_message: str | None = None,
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
    model_limit_new_tokens: int | None = None,
) -> None
Source code in flexeval/core/language_model/openai_api.py
def __init__(
    self,
    model: str = "gpt-3.5-turbo",
    api_headers: dict[str, str] | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    developer_message: str | None = None,
    string_processors: StringProcessor | list[StringProcessor] | None = None,
    model_limit_new_tokens: int | None = None,
) -> None:
    super().__init__(string_processors=string_processors)
    self.model = model
    if api_headers is None:
        api_headers = {}
    client = AsyncOpenAI(**api_headers)
    self.api_call_func = client.chat.completions.create
    self.empty_response = EMPTY_RESPONSE
    self.default_gen_kwargs = default_gen_kwargs or {}
    # convert the flexeval-specific argument name to the OpenAI-specific name
    if "max_new_tokens" in self.default_gen_kwargs:
        self.default_gen_kwargs["max_completion_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

    self.developer_message = developer_message
    self.model_limit_new_tokens = model_limit_new_tokens
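
As a small illustration of the argument-name conversion in the constructor above (a sketch with hypothetical values, same OPENAI_API_KEY assumption as before): a `max_new_tokens` entry in `default_gen_kwargs` is renamed to the OpenAI-style `max_completion_tokens` at construction time.

lm = OpenAIChatAPI(model="gpt-3.5-turbo", default_gen_kwargs={"max_new_tokens": 256})
print(lm.default_gen_kwargs)  # {'max_completion_tokens': 256}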

__repr__

__repr__() -> str
Source code in flexeval/core/language_model/openai_api.py
def __repr__(self) -> str:
    return f"{self.__class__.__name__}(model={self.model})"

OpenAICompletionAPI

LanguageModel implementation using OpenAI's Completion API.

Note that the Completion API is a legacy API, with only a few models (such as gpt-3.5-turbo-instruct) supported by OpenAI. This LanguageModel implementation is primarily intended for use with on-premise VLLM servers, as described in the documentation: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

Parameters:

  • model (str, default: 'gpt-3.5-turbo-instruct' ) –

    The name of the model to use.

  • api_headers (dict[str, str] | None, default: None ) –

    A dictionary of headers to use when making requests to the OpenAI API.

  • default_gen_kwargs (dict[str, Any] | None, default: None ) –

    Default generation kwargs to use when calling the API.

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.
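
A minimal sketch of the on-premise use case mentioned in the description above. Assumptions: a vLLM OpenAI-compatible server is already running at the hypothetical address http://localhost:8000/v1, and, since `api_headers` is forwarded to the `AsyncOpenAI` client (see the source below), client options such as `base_url` and `api_key` can be passed through it. The private `_batch_complete_text` is called directly for illustration.

from flexeval.core.language_model.openai_api import OpenAICompletionAPI

lm = OpenAICompletionAPI(
    model="meta-llama/Llama-3.1-8B",  # hypothetical model name served by the vLLM server
    api_headers={"base_url": "http://localhost:8000/v1", "api_key": "EMPTY"},
    default_gen_kwargs={"max_new_tokens": 64},
)
outputs = lm._batch_complete_text(["The capital of France is"])
print(outputs[0].text, outputs[0].finish_reason)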

Source code in flexeval/core/language_model/openai_api.py
class OpenAICompletionAPI(LanguageModel):
    """LanguageModel implementation using OpenAI's Completion API.

    Note that the Completion API is a legacy API, with only a few models (such as gpt-3.5-turbo-instruct)
    supported by OpenAI. This LanguageModel implementation is primarily intended for use with on-premise
    VLLM servers, as described in the documentation: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    Args:
        model: The name of the model to use.
        api_headers: A dictionary of headers to use when making requests to the OpenAI API.
        default_gen_kwargs: Default generation kwargs to use when calling the API.
        string_processors: A single or a list of StringProcessor objects to process the model's output.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-instruct",
        api_headers: dict[str, str] | None = None,
        default_gen_kwargs: dict[str, Any] | None = None,
        string_processors: StringProcessor | list[StringProcessor] | None = None,
    ) -> None:
        super().__init__(string_processors=string_processors)
        self.model = model
        if api_headers is None:
            api_headers = {}
        client = AsyncOpenAI(**api_headers)
        self.api_call_func = client.completions.create
        self.empty_response = EMPTY_RESPONSE
        self.default_gen_kwargs = default_gen_kwargs or {}
        # convert the flexeval-specific argument name to the OpenAI-specific name
        if "max_new_tokens" in self.default_gen_kwargs:
            self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

    async def _async_batch_run_completion(
        self,
        prompt_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[str]:
        """Send multiple completion requests to the OpenAI in parallel."""

        gen_kwargs = self.default_gen_kwargs.copy()
        gen_kwargs.update(kwargs)
        if max_new_tokens is not None:
            gen_kwargs["max_tokens"] = max_new_tokens

        stop_sequences = normalize_stop_sequences(
            stop_sequences_list=[
                stop_sequences,
                gen_kwargs.pop("stop", None),  # This is used in the OpenAI API
                gen_kwargs.pop("stop_sequences", None),  # This is a common variable name used in flexeval
            ],
        )

        tasks = [
            _retry_on_error(
                # Bind the current prompt via a lambda default argument so that
                # _retry_on_error can lazily (re-)invoke the API call on failure.
                openai_call=lambda x=ms: self.api_call_func(
                    model=self.model,
                    prompt=x,
                    stop=stop_sequences,
                    **gen_kwargs,
                ),
                empty_response=self.empty_response,
            )
            for ms in prompt_list
        ]
        return await asyncio.gather(*tasks)

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        api_responses = asyncio.run(
            self._async_batch_run_completion(
                text_list,
                stop_sequences=stop_sequences,
                max_new_tokens=max_new_tokens,
                **kwargs,
            ),
        )

        return [LMOutput(text=res.choices[0].text, finish_reason=res.choices[0].finish_reason) for res in api_responses]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self.model})"

model instance-attribute

model = model

api_call_func instance-attribute

api_call_func = create

empty_response instance-attribute

empty_response = EMPTY_RESPONSE

default_gen_kwargs instance-attribute

default_gen_kwargs = default_gen_kwargs or {}

__init__

__init__(
    model: str = "gpt-3.5-turbo-instruct",
    api_headers: dict[str, str] | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
) -> None
Source code in flexeval/core/language_model/openai_api.py
def __init__(
    self,
    model: str = "gpt-3.5-turbo-instruct",
    api_headers: dict[str, str] | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    string_processors: StringProcessor | list[StringProcessor] | None = None,
) -> None:
    super().__init__(string_processors=string_processors)
    self.model = model
    if api_headers is None:
        api_headers = {}
    client = AsyncOpenAI(**api_headers)
    self.api_call_func = client.completions.create
    self.empty_response = EMPTY_RESPONSE
    self.default_gen_kwargs = default_gen_kwargs or {}
    # convert the flexeval-specific argument name to the OpenAI-specific name
    if "max_new_tokens" in self.default_gen_kwargs:
        self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

__repr__

__repr__() -> str
Source code in flexeval/core/language_model/openai_api.py
def __repr__(self) -> str:
    return f"{self.__class__.__name__}(model={self.model})"

OpenAIChatBatchAPI

LanguageModel implementation using OpenAI's ChatGPT API via the Batch API. NOTE: The batch size should be greater than or equal to the size of the given dataset for efficient generation.

Parameters:

  • model (str) –

    The name of the model to use.

  • api_headers (dict[str, str] | None, default: None ) –

    A dictionary of headers to use when making requests to the OpenAI API.

  • polling_interval_seconds (int, default: 60 ) –

    The interval in seconds to poll the batch status.

  • default_gen_kwargs (dict[str, Any] | None, default: None ) –

    Default generation kwargs to use when calling the API.

  • developer_message (str | None, default: None ) –

    Instructions to the model that are prioritized ahead of user messages. Previously called the system prompt.

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.

  • model_limit_new_tokens (int | None, default: None ) –

    An upper limit on the number of tokens the model can generate. For example, if a too-large max_new_tokens is given to generate_chat_response(), this value will cap it.
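
A minimal instantiation sketch under stated assumptions: a valid OPENAI_API_KEY, a model name that is enabled for the Batch API (the one below is hypothetical), and tolerance for the polling delay, since each trial blocks for up to polling_interval_seconds between status checks. The private method is called directly for illustration.

from flexeval.core.language_model.openai_batch_api import OpenAIChatBatchAPI

lm = OpenAIChatBatchAPI(
    model="gpt-4o-mini",              # hypothetical; must be available through the Batch API
    polling_interval_seconds=30,      # poll the batch status every 30 seconds
    default_gen_kwargs={"max_new_tokens": 128},
)
outputs = lm._batch_generate_chat_response([[{"role": "user", "content": "Say hello."}]])
print(outputs[0].text)
lm.close()  # remove the temporary JSONL file created in __init__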

Source code in flexeval/core/language_model/openai_batch_api.py
class OpenAIChatBatchAPI(LanguageModel):
    """LanguageModel implementation using OpenAI's ChatGPT API for Batch API.
    NOTE: Batch size should be more than or equal to the size of the given dataset for efficient generation.

    Args:
        model: The name of the model to use.
        api_headers: A dictionary of headers to use when making requests to the OpenAI API.
        polling_interval_seconds: The interval in seconds to poll the batch status.
        default_gen_kwargs: Default generation kwargs to use when calling the API.
        developer_message: Instructions to the model that are prioritized ahead of user messages.
            Previously called the system prompt.
        string_processors: A single or a list of StringProcessor objects to process the model's output.
        model_limit_new_tokens: An upper limit on the number of tokens the model can generate.
            For example, if a too-large `max_new_tokens` is given to generate_chat_response(), this value will cap it.
    """

    def __init__(
        self,
        model: str,
        api_headers: dict[str, str] | None = None,
        polling_interval_seconds: int = 60,
        default_gen_kwargs: dict[str, Any] | None = None,
        developer_message: str | None = None,
        string_processors: StringProcessor | list[StringProcessor] | None = None,
        model_limit_new_tokens: int | None = None,
    ) -> None:
        super().__init__(string_processors=string_processors)
        self.model = model
        if api_headers is None:
            api_headers = {}
        self._client = AsyncOpenAI(**api_headers)
        self.default_gen_kwargs = default_gen_kwargs or {}
        # convert the flexeval-specific argument name to the OpenAI-specific name
        if "max_new_tokens" in self.default_gen_kwargs:
            self.default_gen_kwargs["max_completion_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")
        self.temp_jsonl_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl")

        self.polling_interval_seconds = polling_interval_seconds
        self.developer_message = developer_message
        self.model_limit_new_tokens = model_limit_new_tokens

    def create_batch_file(self, custom_id_2_message: dict[str, list[dict[str, Any]]], **kwargs) -> None:
        with open(self.temp_jsonl_file.name, mode="w") as f:
            for custom_id, message in custom_id_2_message.items():
                if self.developer_message:
                    message = [{"role": "developer", "content": self.developer_message}, *message]  # noqa: PLW2901

                f.write(
                    json.dumps(create_request_details(self.model, custom_id, message, **kwargs), ensure_ascii=False)
                    + "\n",
                )

    async def _post_batch_requests(
        self,
        custom_id_2_message: dict[str, list[dict[str, Any]]],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> str:
        """Send batch chat requests to the OpenAI Batch API."""
        gen_kwargs = self.default_gen_kwargs.copy()
        gen_kwargs.update(kwargs)

        if max_new_tokens is not None:
            if "max_completion_tokens" in gen_kwargs:
                msg = (
                    "You specified both `max_new_tokens` and `max_completion_tokens` in generation kwargs. "
                    "Note that `max_new_tokens` overrides `max_completion_tokens` by default. "
                    "It is recommended to specify only one of them to avoid unexpected behavior."
                )
                logger.warning(msg)
            gen_kwargs["max_completion_tokens"] = max_new_tokens

        if self.model_limit_new_tokens and (gen_kwargs.get("max_completion_tokens", 0) > self.model_limit_new_tokens):
            msg = (
                f"The specified `max_new_tokens` ({gen_kwargs['max_completion_tokens']}) exceeds"
                f"the model’s capability ({self.model_limit_new_tokens} tokens). It will be reduced."
            )
            logger.warning(msg)
            gen_kwargs["max_completion_tokens"] = self.model_limit_new_tokens

        gen_kwargs["stop"] = normalize_stop_sequences(
            stop_sequences_list=[
                stop_sequences,
                gen_kwargs.pop("stop", None),  # This is used in the OpenAI API
                gen_kwargs.pop("stop_sequences", None),  # This is a common variable name used in flexeval
            ],
        )

        self.create_batch_file(custom_id_2_message, **gen_kwargs)

        # Update batch file
        with open(self.temp_jsonl_file.name, "rb") as batch_file:  # noqa: ASYNC101
            batch_input_file = await self._client.files.create(file=batch_file, purpose="batch")

        # Run Job
        # Batch Object: https://platform.openai.com/docs/api-reference/batch/object
        batch_object = await self._client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={"description": "flexeval job"},
        )
        logger.info(f"Input File ID: {batch_input_file.id}, Batch ID: {batch_object.id}")
        return batch_object.id

    async def poll_batch_status_until_completion(
        self,
        batch_id: str,
        polling_interval_seconds: int,
    ) -> tuple[Status, Batch]:
        status = Status.validating
        while status not in (Status.completed, Status.failed, Status.canceled):
            await asyncio.sleep(polling_interval_seconds)
            batch_response = await self._client.batches.retrieve(batch_id)
            status = Status(batch_response.status)
            logger.info(f"Current status: {status.value}")
        return status, batch_response

    def _retrieve_file_content(self, file_id: str) -> list[dict[Any, Any]]:
        file_response = asyncio.run(self._client.files.content(file_id))
        return [json.loads(line) for line in file_response.text.strip().split("\n")]

    def _execute_batch_requests(
        self,
        messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[Any]:
        custom_id_2_message: dict[str, list[dict[str, Any]]] = {
            str(uuid.uuid4()): messages for messages in messages_list
        }
        # The response will be an empty string if the API produces an error.
        custom_id_2_response: dict[str, str | list[dict[str, Any]]] = {
            custom_id: "" for custom_id in custom_id_2_message
        }
        exec_cnt = 1

        while len(custom_id_2_message) > 0:
            if exec_cnt > MAX_NUM_TRIALS:
                break
            logger.info(f"Trial {exec_cnt}")
            exec_cnt += 1
            batch_id = asyncio.run(self._post_batch_requests(custom_id_2_message, **kwargs))

            status, batch_response = asyncio.run(
                self.poll_batch_status_until_completion(batch_id, self.polling_interval_seconds),
            )
            if status is not Status.completed:
                error_message = f"Failed: {batch_response}"
                raise ValueError(error_message)

            # Check error_file_id exists and if exists, log error details.
            error_file_id = batch_response.error_file_id
            # If any request fails, error_file_id is set.
            if error_file_id is not None:
                logger.warning("Request on some messages failed following reason.")
                data: list[dict[str, Any]] = self._retrieve_file_content(error_file_id)
                # [Error](https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8857)
                # instance is embedded in response.
                for data_i in data:
                    error = data_i["response"]
                    logger.warning(f"Failed: {error}")

            output_file_id = batch_response.output_file_id
            # If completion on all input fails, output_file_id is None.
            if output_file_id is None:
                logger.warning("All request failed. Continue...")
                continue

            data: list[dict[str, Any]] = self._retrieve_file_content(output_file_id)
            for data_i in data:
                if data_i["error"] is not None:
                    continue

                custom_id = data_i["custom_id"]
                custom_id_2_message.pop(custom_id)

                custom_id_2_response[custom_id] = data_i["response"]["body"]

        # The remaining elements are all those that failed to complete request.
        if custom_id_2_message:
            logger.warning("The following messages failed to complete request.")
            logger.warning(pformat(list(custom_id_2_message.values())))

        return list(custom_id_2_response.values())

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        messages_list = [[{"role": "user", "content": text}] for text in text_list]
        api_responses = self._execute_batch_requests(
            messages_list,
            stop_sequences=stop_sequences,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
        return [
            LMOutput(text=res["choices"][0]["message"]["content"], finish_reason=res["choices"][0]["finish_reason"])
            for res in api_responses
        ]

    def _batch_generate_chat_response(
        self,
        chat_messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[LMOutput]:
        api_responses = self._execute_batch_requests(
            chat_messages_list,
            **kwargs,
        )
        return [
            LMOutput(text=res["choices"][0]["message"]["content"], finish_reason=res["choices"][0]["finish_reason"])
            for res in api_responses
        ]

    def close(self) -> None:
        # in case that the program fails before the file is initialized in __init__
        if not hasattr(self, "temp_jsonl_file"):
            return

        try:
            self.temp_jsonl_file.close()
            os.unlink(self.temp_jsonl_file.name)  # noqa: PTH108
            logger.info(f"Temporary file deleted: {self.temp_jsonl_file.name}")
        except OSError as e:
            logger.error(f"Error: {e.filename} - {e.strerror}.")

    def _batch_compute_chat_log_probs(
        self,
        prompt_list: list[list[dict[str, Any]]],
        response_list: list[dict[str, Any]],
        temperature: float = 0,
        seed: int = 42,
        top_logprobs: int = 20,
    ) -> list[float]:
        response_contents = [resp["content"] for resp in response_list]
        for response_content in response_contents:
            num_tokens = number_of_tokens_in_openai_model(self.model, response_content)
            if num_tokens > 1:
                err_msg = f"OpenAIChatAPI.batch_compute_chat_log_probs is not applicable for two or more tokens of response content: '{response_content}'"  # noqa: E501
                raise NotImplementedError(err_msg)

        # For saving cost, remove duplication from message_list for an API request.
        unique_prompt_list = remove_duplicates_from_prompt_list(prompt_list)
        api_responses = self._execute_batch_requests(
            unique_prompt_list,
            max_new_tokens=1,
            seed=seed,
            logprobs=True,
            top_logprobs=top_logprobs,
        )

        log_probs = []
        top_logprobs_list = [res["choices"][0]["logprobs"]["content"][0]["top_logprobs"] for res in api_responses]
        for index, prompt in enumerate(prompt_list):
            target_token = response_contents[index]
            index_in_unique = unique_prompt_list.index(prompt)

            log_prob = None  # if target token not in top_logprobs, return None for log_prob of the token
            top_logprobs = top_logprobs_list[index_in_unique]
            for token_logprob in top_logprobs:
                if token_logprob["token"] == target_token:
                    log_prob = token_logprob["logprob"]
                    break
            log_probs.append(log_prob)

        return log_probs

    def __del__(self) -> None:
        self.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self.model})"

model instance-attribute

model = model

default_gen_kwargs instance-attribute

default_gen_kwargs = default_gen_kwargs or {}

temp_jsonl_file instance-attribute

temp_jsonl_file = NamedTemporaryFile(
    delete=False, suffix=".jsonl"
)

polling_interval_seconds instance-attribute

polling_interval_seconds = polling_interval_seconds

developer_message instance-attribute

developer_message = developer_message

model_limit_new_tokens instance-attribute

model_limit_new_tokens = model_limit_new_tokens

__init__

__init__(
    model: str,
    api_headers: dict[str, str] | None = None,
    polling_interval_seconds: int = 60,
    default_gen_kwargs: dict[str, Any] | None = None,
    developer_message: str | None = None,
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
    model_limit_new_tokens: int | None = None,
) -> None
Source code in flexeval/core/language_model/openai_batch_api.py
def __init__(
    self,
    model: str,
    api_headers: dict[str, str] | None = None,
    polling_interval_seconds: int = 60,
    default_gen_kwargs: dict[str, Any] | None = None,
    developer_message: str | None = None,
    string_processors: StringProcessor | list[StringProcessor] | None = None,
    model_limit_new_tokens: int | None = None,
) -> None:
    super().__init__(string_processors=string_processors)
    self.model = model
    if api_headers is None:
        api_headers = {}
    self._client = AsyncOpenAI(**api_headers)
    self.default_gen_kwargs = default_gen_kwargs or {}
    # convert the flexeval-specific argument name to the OpenAI-specific name
    if "max_new_tokens" in self.default_gen_kwargs:
        self.default_gen_kwargs["max_completion_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")
    self.temp_jsonl_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl")

    self.polling_interval_seconds = polling_interval_seconds
    self.developer_message = developer_message
    self.model_limit_new_tokens = model_limit_new_tokens

create_batch_file

create_batch_file(
    custom_id_2_message: dict[str, list[dict[str, Any]]],
    **kwargs,
) -> None
Source code in flexeval/core/language_model/openai_batch_api.py
def create_batch_file(self, custom_id_2_message: dict[str, list[dict[str, Any]]], **kwargs) -> None:
    with open(self.temp_jsonl_file.name, mode="w") as f:
        for custom_id, message in custom_id_2_message.items():
            if self.developer_message:
                message = [{"role": "developer", "content": self.developer_message}, *message]  # noqa: PLW2901

            f.write(
                json.dumps(create_request_details(self.model, custom_id, message, **kwargs), ensure_ascii=False)
                + "\n",
            )
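
As the listing above shows, a configured `developer_message` is prepended to every request's message list before it is written to the batch JSONL file. A small sketch of that transformation with hypothetical values:

developer_message = "Answer concisely."
message = [{"role": "user", "content": "What is 2 + 2?"}]

# Inside create_batch_file, the messages written for this request effectively become:
message = [{"role": "developer", "content": developer_message}, *message]
# -> [{'role': 'developer', 'content': 'Answer concisely.'},
#     {'role': 'user', 'content': 'What is 2 + 2?'}]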

poll_batch_status_until_completion async

poll_batch_status_until_completion(
    batch_id: str, polling_interval_seconds: int
) -> tuple[Status, Batch]
Source code in flexeval/core/language_model/openai_batch_api.py
async def poll_batch_status_until_completion(
    self,
    batch_id: str,
    polling_interval_seconds: int,
) -> tuple[Status, Batch]:
    status = Status.validating
    while status not in (Status.completed, Status.failed, Status.canceled):
        await asyncio.sleep(polling_interval_seconds)
        batch_response = await self._client.batches.retrieve(batch_id)
        status = Status(batch_response.status)
        logger.info(f"Current status: {status.value}")
    return status, batch_response

close

close() -> None
Source code in flexeval/core/language_model/openai_batch_api.py
def close(self) -> None:
    # in case that the program fails before the file is initialized in __init__
    if not hasattr(self, "temp_jsonl_file"):
        return

    try:
        self.temp_jsonl_file.close()
        os.unlink(self.temp_jsonl_file.name)  # noqa: PTH108
        logger.info(f"Temporary file deleted: {self.temp_jsonl_file.name}")
    except OSError as e:
        logger.error(f"Error: {e.filename} - {e.strerror}.")

__del__

__del__() -> None
Source code in flexeval/core/language_model/openai_batch_api.py
def __del__(self) -> None:
    self.close()

__repr__

__repr__() -> str
Source code in flexeval/core/language_model/openai_batch_api.py
def __repr__(self) -> str:
    return f"{self.__class__.__name__}(model={self.model})"

VLLM

LanguageModel implementation using VLLM.

Parameters:

  • model (str) –

    The name of the model to use.

  • model_kwargs (dict[str, Any] | None, default: None ) –

    Additional keyword arguments to pass to the model.

  • tokenizer (str | None, default: None ) –

    The name of the tokenizer to use. Defaults to the model_name.

  • tokenizer_kwargs (dict[str, Any] | None, default: None ) –

    Keyword arguments for the tokenizer instantiation by `from_pretrained()`.

  • add_special_tokens (bool, default: False ) –

    Whether to add special tokens to the input. Note that whether BOS or EOS tokens are added depends on the tokenizer.

  • custom_chat_template (str | None, default: None ) –

    A custom chat template for chatbot models. If specified, this overrides the default chat template of the tokenizer.

  • default_gen_kwargs (dict[str, Any] | None, default: None ) –

    Default generation kwargs to use when calling the model.

  • string_processors (StringProcessor | list[StringProcessor] | None, default: None ) –

    A single or a list of StringProcessor objects to process the model's output.

  • model_limit_tokens (int | None | Literal['default'], default: 'default' ) –

    An upper limit on the number of tokens (input + output) the model can handle. If max_new_tokens exceeds this limit in generate_chat_response(), it will be capped to this value. If this value is set to less than or equal to the model's capacity and the input exceeds it, an empty string is returned instead of raising an error. If set to “default”, the value will be automatically determined when possible.
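
A minimal instantiation sketch under stated assumptions: the vllm extra dependency is installed, at least one CUDA GPU is visible, and the hypothetical model name below is replaced with one your hardware can load. The private method is called directly for illustration.

from flexeval.core.language_model.vllm_model import VLLM

lm = VLLM(
    model="meta-llama/Llama-3.1-8B-Instruct",      # hypothetical Hugging Face model name
    model_kwargs={"gpu_memory_utilization": 0.9},  # forwarded to vllm.LLM
    default_gen_kwargs={"max_new_tokens": 128},    # renamed to `max_tokens` internally
)
outputs = lm._batch_generate_chat_response([[{"role": "user", "content": "Introduce yourself."}]])
print(outputs[0].text, outputs[0].finish_reason)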

Source code in flexeval/core/language_model/vllm_model.py
class VLLM(LanguageModel):
    """LanguageModel implementation using VLLM.

    Args:
        model: The name of the model to use.
        model_kwargs: Additional keyword arguments to pass to the model.
        tokenizer: The name of the tokenizer to use. Defaults to the model_name.
        tokenizer_kwargs: Keyword arguments for the tokenizer instantiation by `from_pretrained()`.
        add_special_tokens: Whether to add special tokens to the input.
            Note that whether BOS or EOS tokens are added depends on the tokenizer.
        custom_chat_template: A custom chat template for chatbot models.
            If specified, this overrides the default chat template of the tokenizer.
        default_gen_kwargs: Default generation kwargs to use when calling the model.
        string_processors: A single or a list of StringProcessor objects to process the model's output.
        model_limit_tokens: An upper limit on the number of tokens (input + output) the model can handle.
            If `max_new_tokens` exceeds this limit in `generate_chat_response()`, it will be capped to this value.
            If this value is set to less than or equal to the model's capacity and the input exceeds it,
            an empty string is returned instead of raising an error.
            If set to “default”, the value will be automatically determined when possible.
    """

    def __init__(
        self,
        model: str,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer: str | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        add_special_tokens: bool = False,
        custom_chat_template: str | None = None,
        default_gen_kwargs: dict[str, Any] | None = None,
        string_processors: StringProcessor | list[StringProcessor] | None = None,
        model_limit_tokens: int | None | Literal["default"] = "default",
    ) -> None:
        super().__init__(string_processors=string_processors)
        self.model_name = model
        tokenizer = tokenizer if tokenizer else model
        tokenizer_kwargs = tokenizer_kwargs or {}
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
        self.custom_chat_template = custom_chat_template
        self.add_special_tokens = add_special_tokens
        # use greedy decoding by default to make it consistent with `HuggingFaceLM`
        self.default_gen_kwargs = default_gen_kwargs or {"temperature": 0.0}
        # convert the flexeval-specific argument name to the vllm-specific name
        if "max_new_tokens" in self.default_gen_kwargs:
            self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

        # import from vllm here because it is an extra dependency
        from vllm import LLM

        model_kwargs = model_kwargs or {}
        # automatically set tensor_parallel_size to the number of GPUs
        if "tensor_parallel_size" not in model_kwargs:
            model_kwargs["tensor_parallel_size"] = torch.cuda.device_count()
        if "enable_chunked_prefill" not in model_kwargs:
            model_kwargs["enable_chunked_prefill"] = True
            model_kwargs["disable_sliding_window"] = True
        self.llm = LLM(model, **model_kwargs)

        if model_limit_tokens == "default":
            model_limit_tokens = self.llm.llm_engine.get_model_config().max_model_len
        self.model_limit_tokens = model_limit_tokens

    def _batch_complete_text(
        self,
        text_list: list[str],
        stop_sequences: str | list[str] | None = None,
        max_new_tokens: int | None = None,
        **kwargs,
    ) -> list[LMOutput]:
        gen_kwargs = self.default_gen_kwargs.copy()
        gen_kwargs.update(kwargs)
        if max_new_tokens is not None:
            gen_kwargs["max_tokens"] = max_new_tokens

        stop_sequences = normalize_stop_sequences(
            stop_sequences_list=[
                stop_sequences,
                gen_kwargs.pop("stop", None),  # This is used in the vllm `SamplingParams`
                gen_kwargs.pop("stop_sequences", None),  # This is a common variable name used in flexeval
            ],
            bos_token=self.tokenizer.bos_token,
            eos_token=self.tokenizer.eos_token,
            ignore_eos=gen_kwargs.get("ignore_eos", False),
        )

        model_inputs = self.tokenizer(
            text_list,
            add_special_tokens=self.add_special_tokens,
            return_token_type_ids=False,
        )

        from vllm import RequestOutput, SamplingParams

        prompt_token_ids: list[list[int]] = model_inputs.input_ids
        sampling_params: list[SamplingParams] = []
        skip_flag_list: list[bool] = []
        for i, input_ids in enumerate(model_inputs.input_ids):
            remaining = self.model_limit_tokens - len(input_ids)
            instance_gen_kwargs = gen_kwargs.copy()
            if remaining <= 0:
                prompt_token_ids[i] = input_ids[:1]
                instance_gen_kwargs["max_tokens"] = 1
                msg = (
                    f"Received input that is longer than `model_limit_tokens = {self.model_limit_tokens}`. "
                    f"The instane returns empty strings."
                )
                logger.warning(msg)
            elif "max_tokens" not in gen_kwargs or remaining < gen_kwargs["max_tokens"]:
                instance_gen_kwargs["max_tokens"] = remaining
            sampling_params.append(SamplingParams(**instance_gen_kwargs, stop=stop_sequences))
            skip_flag_list.append(remaining <= 0)

        vllm_outputs: list[RequestOutput] = self.llm.generate(
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        outputs = []
        for input_token_ids, vllm_output, skip_flag in zip(model_inputs.input_ids, vllm_outputs, skip_flag_list):
            if skip_flag:
                # Treat skipped instances as if they generated an empty string.
                decoded_text = ""
                finish_reason = "input_length_limit"
            else:
                output_token_ids = list(vllm_output.outputs[0].token_ids)
                decoded_text = decode_for_lm_continuation(output_token_ids, input_token_ids, self.tokenizer)
                finish_reason = "length"
            # We manually remove the stop sequences from the generated texts.
            for stop in stop_sequences:
                stop_index = decoded_text.find(stop)
                if stop_index != -1:
                    decoded_text = decoded_text[:stop_index]
                    finish_reason = "stop"

            outputs.append(LMOutput(text=decoded_text, finish_reason=finish_reason))
        return outputs

    def _batch_generate_chat_response(
        self,
        chat_messages_list: list[list[dict[str, Any]]],
        **kwargs,
    ) -> list[LMOutput]:
        chat_messages_as_string = [
            self.tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
                add_generation_prompt=True,
                chat_template=self.custom_chat_template,
            )
            for chat_messages in chat_messages_list
        ]
        return self._batch_complete_text(chat_messages_as_string, **kwargs)

    def _batch_compute_log_probs(
        self, text_list: list[str], prefix_list: list[str] | None = None, stride: int | None = None
    ) -> list[float]:
        batch_size = len(text_list)

        # prepare prefix encoding
        prefix_list = prefix_list if prefix_list else ["" for _ in range(batch_size)]
        # If the prefix is an empty string, replace it with the bos token regardless of the model being trained with it.
        # This is needed to correctly calculate the log probabilities of the first token.
        for i in range(batch_size):
            if prefix_list[i] == "":
                prefix_list[i] = self.tokenizer.bos_token

        batch_prefix_ids = tokenize_text_for_lm_prefix(
            prefix_list,
            self.tokenizer,
            add_special_tokens=self.add_special_tokens,
        )

        # prepare continuation encoding
        # If the last token is a special token, it is treated as a beginning of a new sentence.
        batch_continuation_ids = tokenize_text_for_lm_continuation(
            text_list,
            self.tokenizer,
            as_continuation=[prefix_ids[-1] not in self.tokenizer.all_special_ids for prefix_ids in batch_prefix_ids],
        )

        batch_input_ids = [
            prefix + continuation for prefix, continuation in zip(batch_prefix_ids, batch_continuation_ids)
        ]

        max_length = self.llm.llm_engine.get_model_config().max_seq_len_to_capture
        stride = stride or max_length // 2
        if not (0 < stride < max_length):
            msg = f"stride must be in (0, {max_length}), but got {stride}"
            raise ValueError(msg)
        sequence_length = max([len(input_ids) for input_ids in batch_input_ids])

        from vllm import RequestOutput, SamplingParams
        from vllm.sequence import Logprob

        sampling_params = SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=1)

        batch_logprobs = [0.0] * batch_size
        last_computed_index = 0
        for chunk_start in range(0, sequence_length, stride):
            chunk_end = min(chunk_start + max_length, sequence_length)
            chunk_batch_input_ids = [input_ids[chunk_start:chunk_end] for input_ids in batch_input_ids]
            chunk_batch_input_ids = [
                [self.tokenizer.bos_token_id] if len(chunk_input_ids) == 0 else chunk_input_ids
                for chunk_input_ids in chunk_batch_input_ids
            ]
            chunk_batch_outputs: list[RequestOutput] = self.llm.generate(
                prompt_token_ids=chunk_batch_input_ids,
                sampling_params=sampling_params,
                use_tqdm=False,
            )

            i = 0
            for ids, output, prefix_ids in zip(chunk_batch_input_ids, chunk_batch_outputs, batch_prefix_ids):
                chunk_rest_prefix_length = max(len(prefix_ids) - last_computed_index, 0)
                chunk_continuation_start = last_computed_index - chunk_start + chunk_rest_prefix_length

                # `prompt_logprobs` has the same length as the input `ids`.
                # The i-th element contains the log probabilities of the i-th token in `ids`
                # and the highest-likelihood token at that position.
                # The 0-th element is always `None` because the log probability cannot be computed for it.
                prompt_logprobs: list[dict[int, Logprob] | None] = output.prompt_logprobs
                all_token_logprobs = [
                    cands[token_id].logprob if cands else 0.0 for cands, token_id in zip(prompt_logprobs, ids)
                ]
                continuation_logprob = float(sum(all_token_logprobs[chunk_continuation_start:]))
                batch_logprobs[i] += continuation_logprob
                i += 1

            last_computed_index = chunk_end

        return batch_logprobs

    def _batch_compute_chat_log_probs(
        self, prompt_list: list[list[dict[str, Any]]], response_list: list[dict[str, Any]]
    ) -> list[float]:
        prompt_as_string: list[str] = []
        response_as_string: list[str] = []
        for prompt, response in zip(prompt_list, response_list):
            prompt_as_string_i, response_as_string_i = get_prefix_and_completion_from_chat(
                prompt,
                response,
                self.tokenizer,
                custom_chat_template=self.custom_chat_template,
            )
            prompt_as_string.append(prompt_as_string_i)
            response_as_string.append(response_as_string_i)
        return self._batch_compute_log_probs(response_as_string, prefix_list=prompt_as_string)

    def __repr__(self) -> str:
        return f"VLLM(model_name={self.model_name})"

model_name instance-attribute

model_name = model

tokenizer instance-attribute

tokenizer: PreTrainedTokenizer = from_pretrained(
    tokenizer, **tokenizer_kwargs
)

custom_chat_template instance-attribute

custom_chat_template = custom_chat_template

add_special_tokens instance-attribute

add_special_tokens = add_special_tokens

default_gen_kwargs instance-attribute

default_gen_kwargs = default_gen_kwargs or {
    "temperature": 0.0
}

llm instance-attribute

llm = LLM(model, **model_kwargs)

model_limit_tokens instance-attribute

model_limit_tokens = model_limit_tokens

__init__

__init__(
    model: str,
    model_kwargs: dict[str, Any] | None = None,
    tokenizer: str | None = None,
    tokenizer_kwargs: dict[str, Any] | None = None,
    add_special_tokens: bool = False,
    custom_chat_template: str | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    string_processors: StringProcessor
    | list[StringProcessor]
    | None = None,
    model_limit_tokens: int
    | None
    | Literal["default"] = "default",
) -> None
Source code in flexeval/core/language_model/vllm_model.py
def __init__(
    self,
    model: str,
    model_kwargs: dict[str, Any] | None = None,
    tokenizer: str | None = None,
    tokenizer_kwargs: dict[str, Any] | None = None,
    add_special_tokens: bool = False,
    custom_chat_template: str | None = None,
    default_gen_kwargs: dict[str, Any] | None = None,
    string_processors: StringProcessor | list[StringProcessor] | None = None,
    model_limit_tokens: int | None | Literal["default"] = "default",
) -> None:
    super().__init__(string_processors=string_processors)
    self.model_name = model
    tokenizer = tokenizer if tokenizer else model
    tokenizer_kwargs = tokenizer_kwargs or {}
    self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
    self.custom_chat_template = custom_chat_template
    self.add_special_tokens = add_special_tokens
    # use greedy decoding by default to make it consistent with `HuggingFaceLM`
    self.default_gen_kwargs = default_gen_kwargs or {"temperature": 0.0}
    # convert the flexeval-specific argument name to the vllm-specific name
    if "max_new_tokens" in self.default_gen_kwargs:
        self.default_gen_kwargs["max_tokens"] = self.default_gen_kwargs.pop("max_new_tokens")

    # import from vllm here because it is an extra dependency
    from vllm import LLM

    model_kwargs = model_kwargs or {}
    # automatically set tensor_parallel_size to the number of GPUs
    if "tensor_parallel_size" not in model_kwargs:
        model_kwargs["tensor_parallel_size"] = torch.cuda.device_count()
    if "enable_chunked_prefill" not in model_kwargs:
        model_kwargs["enable_chunked_prefill"] = True
        model_kwargs["disable_sliding_window"] = True
    self.llm = LLM(model, **model_kwargs)

    if model_limit_tokens == "default":
        model_limit_tokens = self.llm.llm_engine.get_model_config().max_model_len
    self.model_limit_tokens = model_limit_tokens

__repr__

__repr__() -> str
Source code in flexeval/core/language_model/vllm_model.py
def __repr__(self) -> str:
    return f"VLLM(model_name={self.model_name})"