# Source code for bundled_program.config
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from dataclasses import dataclass
from typing import get_args, List, Optional, Sequence, Union
import torch
from torch.utils._pytree import tree_flatten
from typing_extensions import TypeAlias
"""
The data types currently supported for element to be bundled. It should be
consistent with the types in bundled_program.schema.Value.
"""
ConfigValue: TypeAlias = Union[
torch.Tensor,
int,
bool,
float,
]
"""
The data type of the input for method single execution.
"""
MethodInputType: TypeAlias = Sequence[ConfigValue]
"""
The data type of the output for method single execution.
"""
MethodOutputType: TypeAlias = Sequence[torch.Tensor]
"""
All supported types for input/expected output of MethodTestCase.
Namedtuple is also supported and listed implicity since it is a subclass of tuple.
"""
# pyre-ignore
DataContainer: TypeAlias = Union[list, tuple, dict]
class MethodTestCase:
    """Test case with inputs and optional expected outputs for one method.

    The expected outputs are optional and only required if the user wants to
    verify model outputs after execution.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with a specific
                inference method for one-time execution. It is worth
                mentioning that, although both bundled program and ET runtime
                apis support setting inputs other than torch.tensor type, only
                the inputs of torch.tensor type will actually be updated in the
                method; the rest of the inputs will just be sanity-checked
                against the default value in the method.
            expected_outputs: Expected outputs for the given inputs, used for
                verification. It can be None if the user only wants to use the
                test case for profiling; ``self.expected_outputs`` is then an
                empty list.

        Raises:
            AssertionError: If any flattened leaf of ``inputs`` or
                ``expected_outputs`` is None or is not an instance of one of
                the types in ``ConfigValue``.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Misaligned data type between MethodTestCase attribute and sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Misaligned data type between MethodTestCase attribute and sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every leaf has a legal type.

        Args:
            unflatten_data: Data to be flattened with ``tree_flatten``
                (a list / tuple / dict container, possibly nested).

        Returns:
            The flattened leaves, in the order produced by ``tree_flatten``.

        Raises:
            AssertionError: If a leaf is None or is not an instance of one of
                the types in ``ConfigValue``. Raised explicitly instead of via
                ``assert`` so the check is not stripped under ``python -O``;
                the exception type is kept for callers that catch it.
        """
        # Hoisted: the set of supported leaf types does not change per leaf.
        supported_types = get_args(ConfigValue)
        flatten_data, _ = tree_flatten(unflatten_data)
        for data in flatten_data:
            # Check None first. None is never an instance of a supported type,
            # so a None check placed after the isinstance check can never fire.
            if data is None:
                raise AssertionError(
                    "The input {} should not be in null type.\n".format(data)
                )
            if not isinstance(data, supported_types):
                raise AssertionError(
                    "The type of input {} with type {} is not supported.\n".format(
                        data, type(data)
                    )
                )
        return flatten_data
@dataclass
class MethodTestSuite:
    """All test info related to verifying a method.

    Attributes:
        method_name: Name of the method to be verified.
        test_cases: All test cases for verifying the method.
    """

    method_name: str
    test_cases: Sequence[MethodTestCase]