diff --git a/api/schemas.py b/api/schemas.py index 6e59bdafd..c5d6945db 100644 --- a/api/schemas.py +++ b/api/schemas.py @@ -742,12 +742,6 @@ class MetricType(str, Enum): table = "table" -# class MetricTypeTableGroupField(str, Enum): -# users = "USERS" -# click_rage = IssueType.click_rage -# dead_click = IssueType.dead_click - - class CustomMetricRawPayloadSchema(BaseModel): startDate: int = Field(TimeUTC.now(-7)) endDate: int = Field(TimeUTC.now()) @@ -760,7 +754,7 @@ class CustomMetricRawPayloadSchema2(CustomMetricRawPayloadSchema): metric_id: int = Field(...) -class MetricOfType(str, Enum): +class TableMetricOfType(str, Enum): user_os = FilterType.user_os.value user_browser = FilterType.user_browser.value user_device = FilterType.user_device.value @@ -770,14 +764,25 @@ class MetricOfType(str, Enum): visited_url = "VISITED_URL" +class TimeseriesMetricOfType(str, Enum): + session_count = "sessionCount" + + class CustomMetricChartPayloadSchema(CustomMetricRawPayloadSchema): startDate: int = Field(TimeUTC.now(-7)) endDate: int = Field(TimeUTC.now()) density: int = Field(7) viewType: MetricViewType = Field(MetricViewType.line_chart) metricType: MetricType = Field(MetricType.timeseries) - metricOf: MetricOfType = Field(MetricOfType.user_id) - metricFraction: float = Field(None, gt=0, lt=1) + metricOf: Union[TableMetricOfType, TimeseriesMetricOfType] = Field(TableMetricOfType.user_id) + + # metricFraction: float = Field(None, gt=0, lt=1) + @root_validator + def validator(cls, values): + if isinstance(values.get("metricOf"), TimeseriesMetricOfType): + assert values.get("metricType") == MetricType.timeseries, \ + f"Only metricType:{MetricType.timeseries.value} is allowed for metricOf: {values.get('metricOf')}" + return values class CustomMetricChartPayloadSchema2(CustomMetricChartPayloadSchema):