diff --git a/README.md b/README.md index c17a596..0743342 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,8 @@ type Greeting struct { Hello string `json:"hello"` } -func greet(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil +func greet(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil } func main() { @@ -82,20 +82,51 @@ That's it. ShiftAPI reflects your Go types into an OpenAPI 3.1 spec at `/openapi ### Generic type-safe handlers -Generic free functions capture your request and response types at compile time. Handlers with a body (`Post`, `Put`, `Patch`) receive the decoded request as a typed value. Handlers without a body (`Get`, `Delete`, `Head`) just receive the request. +Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`) from body fields (`json:"..."`). For routes without input, use `_ struct{}`. ```go -// POST — body is decoded and passed as *CreateUser -shiftapi.Post(api, "/users", func(r *http.Request, body *CreateUser) (*User, error) { - return db.CreateUser(r.Context(), body) +// POST with body — input is decoded and passed as *CreateUser +shiftapi.Post(api, "/users", func(r *http.Request, in *CreateUser) (*User, error) { + return db.CreateUser(r.Context(), in) }, shiftapi.WithStatus(http.StatusCreated)) -// GET — standard *http.Request, use PathValue for path params -shiftapi.Get(api, "/users/{id}", func(r *http.Request) (*User, error) { +// GET without input — use _ struct{} +shiftapi.Get(api, "/users/{id}", func(r *http.Request, _ struct{}) (*User, error) { return db.GetUser(r.Context(), r.PathValue("id")) }) ``` +### Typed query parameters + +Define a struct with `query` tags. Query params are parsed, validated, and documented in the OpenAPI spec automatically. + +```go +type SearchQuery struct { + Q string `query:"q" validate:"required"` + Page int `query:"page" validate:"min=1"` + Limit int `query:"limit" validate:"min=1,max=100"` +} + +shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*Results, error) { + return doSearch(in.Q, in.Page, in.Limit), nil +}) +``` + +Supports `string`, `bool`, `int*`, `uint*`, `float*` scalars, `*T` pointers for optional params, and `[]T` slices for repeated params (e.g. `?tag=a&tag=b`). Parse errors return `400`; validation failures return `422`. + +For handlers that need both query parameters and a request body, combine them in a single struct — fields with `query` tags become query params, fields with `json` tags become the body: + +```go +type CreateInput struct { + DryRun bool `query:"dry_run"` + Name string `json:"name"` +} + +shiftapi.Post(api, "/items", func(r *http.Request, in CreateInput) (*Result, error) { + return createItem(in.Name, in.DryRun), nil +}) +``` + ### Validation Built-in validation via [go-playground/validator](https://github.com/go-playground/validator). Struct tags are enforced at runtime *and* reflected into the OpenAPI schema. @@ -198,6 +229,11 @@ const { data: greeting } = await client.POST("/greet", { body: { name: "frank" }, }); // body and response are fully typed from your Go structs + +const { data: results } = await client.GET("/search", { + params: { query: { q: "hello", page: 1, limit: 10 } }, +}); +// query params are fully typed too — { q: string, page?: number, limit?: number } ``` In dev mode the plugin also starts the Go server, proxies API requests through Vite, watches `.go` files, and hot-reloads the frontend when types change. diff --git a/examples/greeter/main.go b/examples/greeter/main.go index d29a829..8faa07b 100644 --- a/examples/greeter/main.go +++ b/examples/greeter/main.go @@ -15,18 +15,38 @@ type Greeting struct { Hello string `json:"hello"` } -func greet(r *http.Request, body *Person) (*Greeting, error) { - if body.Name != "frank" { +func greet(r *http.Request, in *Person) (*Greeting, error) { + if in.Name != "frank" { return nil, shiftapi.Error(http.StatusBadRequest, "wrong name, I only greet frank") } - return &Greeting{Hello: body.Name}, nil + return &Greeting{Hello: in.Name}, nil +} + +type SearchQuery struct { + Q string `query:"q" validate:"required"` + Page int `query:"page" validate:"min=1"` + Limit int `query:"limit" validate:"min=1,max=100"` +} + +type SearchResult struct { + Query string `json:"query"` + Page int `json:"page"` + Limit int `json:"limit"` +} + +func search(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{ + Query: in.Q, + Page: in.Page, + Limit: in.Limit, + }, nil } type Status struct { OK bool `json:"ok"` } -func health(r *http.Request) (*Status, error) { +func health(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil } @@ -43,6 +63,14 @@ func main() { }), ) + shiftapi.Get(api, "/search", search, + shiftapi.WithRouteInfo(shiftapi.RouteInfo{ + Summary: "Search for things", + Description: "Search with typed query parameters", + Tags: []string{"search"}, + }), + ) + shiftapi.Get(api, "/health", health, shiftapi.WithRouteInfo(shiftapi.RouteInfo{ Summary: "Health check", diff --git a/handler.go b/handler.go index bff3adf..a827de3 100644 --- a/handler.go +++ b/handler.go @@ -5,37 +5,51 @@ import ( "errors" "log" "net/http" + "reflect" ) -// HandlerFunc is a typed handler for methods without a request body (GET, DELETE, HEAD, etc.). -type HandlerFunc[Resp any] func(r *http.Request) (Resp, error) +// HandlerFunc is a typed handler for routes. +// The In struct's fields are discriminated by struct tags: +// fields with `query:"..."` tags are parsed from query parameters, +// and fields with `json:"..."` tags (or no query tag) are parsed from the request body. +// For routes without input, use struct{} as the In type. +type HandlerFunc[In, Resp any] func(r *http.Request, in In) (Resp, error) -// HandlerFuncWithBody is a typed handler for methods with a request body (POST, PUT, PATCH, etc.). -type HandlerFuncWithBody[Body, Resp any] func(r *http.Request, body Body) (Resp, error) - -func adapt[Resp any](fn HandlerFunc[Resp], status int) http.HandlerFunc { +func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasQuery, hasBody bool) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - resp, err := fn(r) - if err != nil { - writeError(w, err) - return + var in In + rv := reflect.ValueOf(&in).Elem() + + // JSON-decode body if there are body fields + if hasBody { + if err := json.NewDecoder(r.Body).Decode(&in); err != nil { + writeError(w, Error(http.StatusBadRequest, "invalid request body")) + return + } + // Re-point rv after decode (in case In is a pointer that was nil) + rv = reflect.ValueOf(&in).Elem() } - writeJSON(w, status, resp) - } -} -func adaptWithBody[Body, Resp any](fn HandlerFuncWithBody[Body, Resp], status int, validate func(any) error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - var body Body - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - writeError(w, Error(http.StatusBadRequest, "invalid request body")) - return + // Reset any query-tagged fields that body decode may have + // inadvertently set, so they only come from URL query params. + if hasBody && hasQuery { + resetQueryFields(rv) } - if err := validate(body); err != nil { + + // Parse query params if there are query fields + if hasQuery { + if err := parseQueryInto(rv, r.URL.Query()); err != nil { + writeError(w, Error(http.StatusBadRequest, err.Error())) + return + } + } + + if err := validate(in); err != nil { writeError(w, err) return } - resp, err := fn(r, body) + + resp, err := fn(r, in) if err != nil { writeError(w, err) return diff --git a/handlerFuncs.go b/handlerFuncs.go index 0f0dec6..5b7810e 100644 --- a/handlerFuncs.go +++ b/handlerFuncs.go @@ -6,93 +6,94 @@ import ( "reflect" ) -func registerRoute[Resp any]( +func registerRoute[In, Resp any]( api *API, method string, path string, - fn HandlerFunc[Resp], + fn HandlerFunc[In, Resp], options ...RouteOption, ) { cfg := applyRouteOptions(options) - var resp Resp - outType := reflect.TypeOf(resp) - - if err := api.updateSchema(method, path, nil, outType, cfg.info, cfg.status); err != nil { - panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) + var in In + inType := reflect.TypeOf(in) + // Dereference pointer to get the underlying struct type + rawInType := inType + for rawInType != nil && rawInType.Kind() == reflect.Pointer { + rawInType = rawInType.Elem() } - pattern := fmt.Sprintf("%s %s", method, path) - api.mux.HandleFunc(pattern, adapt(fn, cfg.status)) -} + hasQuery, hasBody := partitionFields(rawInType) -func registerRouteWithBody[Body, Resp any]( - api *API, - method string, - path string, - fn HandlerFuncWithBody[Body, Resp], - options ...RouteOption, -) { - cfg := applyRouteOptions(options) + var queryType reflect.Type + if hasQuery { + queryType = rawInType + } + // POST/PUT/PATCH conventionally carry a request body, so always attempt + // body decode for these methods — even when the input is struct{}. + // This means Post(api, path, func(r, _ struct{}) ...) requires at least "{}". + methodRequiresBody := method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch + decodeBody := hasBody || methodRequiresBody + + var bodyType reflect.Type + if hasBody { + bodyType = inType + } else if methodRequiresBody { + bodyType = rawInType + } - var body Body - inType := reflect.TypeOf(body) var resp Resp outType := reflect.TypeOf(resp) - if err := api.updateSchema(method, path, inType, outType, cfg.info, cfg.status); err != nil { + if err := api.updateSchema(method, path, queryType, bodyType, outType, cfg.info, cfg.status); err != nil { panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) } pattern := fmt.Sprintf("%s %s", method, path) - api.mux.HandleFunc(pattern, adaptWithBody(fn, cfg.status, api.validateBody)) + api.mux.HandleFunc(pattern, adapt(fn, cfg.status, api.validateBody, hasQuery, decodeBody)) } -// No-body methods - // Get registers a GET handler. -func Get[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { +func Get[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { registerRoute(api, http.MethodGet, path, fn, options...) } +// Post registers a POST handler. +func Post[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { + registerRoute(api, http.MethodPost, path, fn, options...) +} + +// Put registers a PUT handler. +func Put[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { + registerRoute(api, http.MethodPut, path, fn, options...) +} + +// Patch registers a PATCH handler. +func Patch[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { + registerRoute(api, http.MethodPatch, path, fn, options...) +} + // Delete registers a DELETE handler. -func Delete[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { +func Delete[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { registerRoute(api, http.MethodDelete, path, fn, options...) } // Head registers a HEAD handler. -func Head[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { +func Head[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { registerRoute(api, http.MethodHead, path, fn, options...) } // Options registers an OPTIONS handler. -func Options[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { +func Options[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { registerRoute(api, http.MethodOptions, path, fn, options...) } // Trace registers a TRACE handler. -func Trace[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { +func Trace[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { registerRoute(api, http.MethodTrace, path, fn, options...) } -// Body methods - -// Post registers a POST handler. -func Post[Body, Resp any](api *API, path string, fn HandlerFuncWithBody[Body, Resp], options ...RouteOption) { - registerRouteWithBody(api, http.MethodPost, path, fn, options...) -} - -// Put registers a PUT handler. -func Put[Body, Resp any](api *API, path string, fn HandlerFuncWithBody[Body, Resp], options ...RouteOption) { - registerRouteWithBody(api, http.MethodPut, path, fn, options...) -} - -// Patch registers a PATCH handler. -func Patch[Body, Resp any](api *API, path string, fn HandlerFuncWithBody[Body, Resp], options ...RouteOption) { - registerRouteWithBody(api, http.MethodPatch, path, fn, options...) -} - // Connect registers a CONNECT handler. -func Connect[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { +func Connect[In, Resp any](api *API, path string, fn HandlerFunc[In, Resp], options ...RouteOption) { registerRoute(api, http.MethodConnect, path, fn, options...) } diff --git a/packages/create-shiftapi/templates/base/internal/server/server.go b/packages/create-shiftapi/templates/base/internal/server/server.go index a2dca8b..673a82c 100644 --- a/packages/create-shiftapi/templates/base/internal/server/server.go +++ b/packages/create-shiftapi/templates/base/internal/server/server.go @@ -15,15 +15,15 @@ type EchoResponse struct { Message string `json:"message"` } -func echo(r *http.Request, body *EchoRequest) (*EchoResponse, error) { - return &EchoResponse{Message: body.Message}, nil +func echo(r *http.Request, in *EchoRequest) (*EchoResponse, error) { + return &EchoResponse{Message: in.Message}, nil } type Status struct { OK bool `json:"ok"` } -func health(r *http.Request) (*Status, error) { +func health(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil } diff --git a/query.go b/query.go new file mode 100644 index 0000000..d113f83 --- /dev/null +++ b/query.go @@ -0,0 +1,187 @@ +package shiftapi + +import ( + "fmt" + "net/url" + "reflect" + "strconv" + "strings" +) + +// hasQueryTag returns true if the struct field has a `query` tag. +func hasQueryTag(f reflect.StructField) bool { + return f.Tag.Get("query") != "" +} + +// partitionFields inspects a struct type and reports whether it contains +// query-tagged fields and/or body (json-tagged or untagged non-query) fields. +func partitionFields(t reflect.Type) (hasQuery, hasBody bool) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return false, false + } + for i := range t.NumField() { + f := t.Field(i) + if !f.IsExported() { + continue + } + if hasQueryTag(f) { + hasQuery = true + } else { + // Any exported field without a query tag is a body field + jsonTag := f.Tag.Get("json") + if jsonTag == "-" { + continue + } + hasBody = true + } + } + return +} + +// resetQueryFields zeros out any query-tagged fields on a struct value. +// This is called after body decode so that query-tagged fields are only +// populated by parseQueryInto, not by JSON keys that happen to match. +func resetQueryFields(rv reflect.Value) { + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return + } + rt := rv.Type() + for i := range rt.NumField() { + f := rt.Field(i) + if f.IsExported() && hasQueryTag(f) { + rv.Field(i).SetZero() + } + } +} + +// parseQueryInto populates query-tagged fields on an existing struct value +// from URL query parameters. Non-query fields are left untouched. +func parseQueryInto(rv reflect.Value, values url.Values) error { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + rt := rv.Type() + if rt.Kind() != reflect.Struct { + return fmt.Errorf("query type must be a struct, got %s", rt.Kind()) + } + + for i := range rt.NumField() { + field := rt.Field(i) + if !field.IsExported() || !hasQueryTag(field) { + continue + } + + name := queryFieldName(field) + fv := rv.Field(i) + ft := field.Type + + // Handle pointer fields (optional params) + if ft.Kind() == reflect.Pointer { + rawValues, exists := values[name] + if !exists || len(rawValues) == 0 { + continue + } + ptr := reflect.New(ft.Elem()) + if err := setScalarValue(ptr.Elem(), rawValues[0]); err != nil { + return &queryParseError{Field: name, Err: err} + } + fv.Set(ptr) + continue + } + + // Handle slice fields + if ft.Kind() == reflect.Slice { + rawValues, exists := values[name] + if !exists || len(rawValues) == 0 { + continue + } + elemType := ft.Elem() + slice := reflect.MakeSlice(ft, len(rawValues), len(rawValues)) + for j, raw := range rawValues { + elem := reflect.New(elemType).Elem() + if err := setScalarValue(elem, raw); err != nil { + return &queryParseError{Field: name, Err: err} + } + slice.Index(j).Set(elem) + } + fv.Set(slice) + continue + } + + // Handle scalar fields + raw := values.Get(name) + if raw == "" { + continue + } + if err := setScalarValue(fv, raw); err != nil { + return &queryParseError{Field: name, Err: err} + } + } + + return nil +} + +// queryFieldName returns the query parameter name for a struct field. +// The field must have a non-empty `query` tag (guaranteed by hasQueryTag). +func queryFieldName(f reflect.StructField) string { + name, _, _ := strings.Cut(f.Tag.Get("query"), ",") + if name == "" { + return f.Name + } + return name +} + +// setScalarValue parses a string and sets the value on a reflect.Value. +func setScalarValue(v reflect.Value, raw string) error { + switch v.Kind() { + case reflect.String: + v.SetString(raw) + case reflect.Bool: + b, err := strconv.ParseBool(raw) + if err != nil { + return fmt.Errorf("invalid boolean value %q", raw) + } + v.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(raw, 10, v.Type().Bits()) + if err != nil { + return fmt.Errorf("invalid integer value %q", raw) + } + v.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n, err := strconv.ParseUint(raw, 10, v.Type().Bits()) + if err != nil { + return fmt.Errorf("invalid unsigned integer value %q", raw) + } + v.SetUint(n) + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(raw, v.Type().Bits()) + if err != nil { + return fmt.Errorf("invalid float value %q", raw) + } + v.SetFloat(n) + default: + return fmt.Errorf("unsupported query parameter type %s", v.Kind()) + } + return nil +} + +// queryParseError is returned when a query parameter cannot be parsed. +type queryParseError struct { + Field string + Err error +} + +func (e *queryParseError) Error() string { + return fmt.Sprintf("invalid query parameter %q: %v", e.Field, e.Err) +} diff --git a/schema.go b/schema.go index 5e75923..4d6f852 100644 --- a/schema.go +++ b/schema.go @@ -12,7 +12,7 @@ import ( var pathParamRe = regexp.MustCompile(`\{([^}]+)\}`) -func (a *API) updateSchema(method, path string, inType, outType reflect.Type, info *RouteInfo, status int) error { +func (a *API) updateSchema(method, path string, queryType, inType, outType reflect.Type, info *RouteInfo, status int) error { op := &openapi3.Operation{ OperationID: operationID(method, path), Responses: openapi3.NewResponses(), @@ -34,6 +34,15 @@ func (a *API) updateSchema(method, path string, inType, outType reflect.Type, in }) } + // Query parameters + if queryType != nil { + queryParams, err := a.generateQueryParams(queryType) + if err != nil { + return err + } + op.Parameters = append(op.Parameters, queryParams...) + } + // Response schema statusStr := fmt.Sprintf("%d", status) outSchema, err := a.generateSchemaRef(outType) @@ -112,20 +121,44 @@ func (a *API) updateSchema(method, path string, inType, outType reflect.Type, in return err } if inSchema != nil { - content := make(map[string]*openapi3.MediaType) - content["application/json"] = &openapi3.MediaType{ - Schema: &openapi3.SchemaRef{ - Ref: fmt.Sprintf("#/components/schemas/%s", inSchema.Ref), - }, - } - op.RequestBody = &openapi3.RequestBodyRef{ - Value: &openapi3.RequestBody{ - Required: true, - Content: content, - }, - } - a.spec.Components.Schemas[inSchema.Ref] = &openapi3.SchemaRef{ - Value: inSchema.Value, + // Strip query-tagged fields from the body schema + stripQueryFields(inType, inSchema.Value) + + if len(inSchema.Value.Properties) > 0 { + // Named body schema with properties + content := make(map[string]*openapi3.MediaType) + content["application/json"] = &openapi3.MediaType{ + Schema: &openapi3.SchemaRef{ + Ref: fmt.Sprintf("#/components/schemas/%s", inSchema.Ref), + }, + } + op.RequestBody = &openapi3.RequestBodyRef{ + Value: &openapi3.RequestBody{ + Required: true, + Content: content, + }, + } + a.spec.Components.Schemas[inSchema.Ref] = &openapi3.SchemaRef{ + Value: inSchema.Value, + } + } else { + // No body fields (e.g. struct{}) — inline empty object schema. + // This happens for POST/PUT/PATCH where a body is required + // even when the input struct has no body fields. + op.RequestBody = &openapi3.RequestBodyRef{ + Value: &openapi3.RequestBody{ + Required: true, + Content: map[string]*openapi3.MediaType{ + "application/json": { + Schema: &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + }, + }, + }, + }, + }, + } } } } @@ -210,6 +243,116 @@ func (a *API) generateSchemaRef(t reflect.Type) (*openapi3.SchemaRef, error) { return schema, nil } +// generateQueryParams produces OpenAPI parameter definitions for a query struct type. +// Only fields with `query` tags are included. +func (a *API) generateQueryParams(t reflect.Type) ([]*openapi3.ParameterRef, error) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("query type must be a struct, got %s", t.Kind()) + } + + var params []*openapi3.ParameterRef + for i := range t.NumField() { + field := t.Field(i) + if !field.IsExported() { + continue + } + if !hasQueryTag(field) { + continue + } + name := queryFieldName(field) + schema := fieldToOpenAPISchema(field.Type) + + // Apply validation constraints + if err := validateSchemaCustomizer(name, field.Type, field.Tag, schema.Value); err != nil { + return nil, err + } + + required := hasRule(field.Tag.Get("validate"), "required") + + params = append(params, &openapi3.ParameterRef{ + Value: &openapi3.Parameter{ + Name: name, + In: "query", + Required: required, + Schema: schema, + }, + }) + } + return params, nil +} + +// fieldToOpenAPISchema maps a Go type to an OpenAPI schema. +func fieldToOpenAPISchema(t reflect.Type) *openapi3.SchemaRef { + // Unwrap pointer + if t.Kind() == reflect.Pointer { + return fieldToOpenAPISchema(t.Elem()) + } + + // Handle slices + if t.Kind() == reflect.Slice { + items := scalarToOpenAPISchema(t.Elem()) + return &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"array"}, + Items: items, + }, + } + } + + return scalarToOpenAPISchema(t) +} + +// scalarToOpenAPISchema maps a scalar Go type to an OpenAPI schema. +func scalarToOpenAPISchema(t reflect.Type) *openapi3.SchemaRef { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + switch t.Kind() { + case reflect.String: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + case reflect.Bool: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"boolean"}}} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}} + case reflect.Float32, reflect.Float64: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}} + default: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + } +} + +// stripQueryFields removes query-tagged fields from a body schema's Properties and Required. +func stripQueryFields(t reflect.Type, schema *openapi3.Schema) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct || schema == nil { + return + } + for i := range t.NumField() { + f := t.Field(i) + if !f.IsExported() || !hasQueryTag(f) { + continue + } + jname := jsonFieldName(f) + if jname == "" || jname == "-" { + continue + } + delete(schema.Properties, jname) + // Remove from Required slice + for j, req := range schema.Required { + if req == jname { + schema.Required = append(schema.Required[:j], schema.Required[j+1:]...) + break + } + } + } +} + func scrubRefs(s *openapi3.SchemaRef) { if s == nil || s.Value == nil || len(s.Value.Properties) == 0 { return diff --git a/serve_test.go b/serve_test.go index b7ffa79..9c8665a 100644 --- a/serve_test.go +++ b/serve_test.go @@ -15,7 +15,7 @@ func TestExportSpec(t *testing.T) { Title: "Export Test", Version: "1.0.0", })) - Get(api, "/health", func(r *http.Request) (*struct { + Get(api, "/health", func(r *http.Request, _ struct{}) (*struct { OK bool `json:"ok"` }, error) { return &struct { diff --git a/shiftapi_test.go b/shiftapi_test.go index a25482a..f73ffa7 100644 --- a/shiftapi_test.go +++ b/shiftapi_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/httptest" + "slices" "strings" "testing" @@ -174,7 +175,7 @@ func TestServeOpenAPISpec(t *testing.T) { Title: "Spec Test", Version: "2.0", })) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -240,8 +241,8 @@ func TestRootRedirectsToDocs(t *testing.T) { func TestPostHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }) resp := doRequest(t, api, http.MethodPost, "/greet", `{"name":"alice"}`) @@ -256,8 +257,8 @@ func TestPostHandler(t *testing.T) { func TestPostHandlerInvalidBody(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }) resp := doRequest(t, api, http.MethodPost, "/greet", `not json`) @@ -268,8 +269,8 @@ func TestPostHandlerInvalidBody(t *testing.T) { func TestPostHandlerEmptyBody(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }) resp := doRequest(t, api, http.MethodPost, "/greet", "") @@ -280,8 +281,8 @@ func TestPostHandlerEmptyBody(t *testing.T) { func TestPostHandlerEmptyJSONObject(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/person", `{}`) @@ -294,7 +295,7 @@ func TestPostHandlerEmptyJSONObject(t *testing.T) { func TestGetHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -310,7 +311,7 @@ func TestGetHandler(t *testing.T) { func TestGetHandlerWithPathParam(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/items/{id}", func(r *http.Request) (*Item, error) { + shiftapi.Get(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Item, error) { return &Item{ID: r.PathValue("id"), Name: "widget"}, nil }) @@ -331,9 +332,9 @@ func TestGetHandlerWithPathParam(t *testing.T) { func TestPutHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Put(api, "/items/{id}", func(r *http.Request, body *Item) (*Item, error) { - body.ID = r.PathValue("id") - return body, nil + shiftapi.Put(api, "/items/{id}", func(r *http.Request, in *Item) (*Item, error) { + in.ID = r.PathValue("id") + return in, nil }) resp := doRequest(t, api, http.MethodPut, "/items/42", `{"name":"updated"}`) @@ -353,9 +354,9 @@ func TestPutHandler(t *testing.T) { func TestPatchHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Patch(api, "/items/{id}", func(r *http.Request, body *Item) (*Item, error) { - body.ID = r.PathValue("id") - return body, nil + shiftapi.Patch(api, "/items/{id}", func(r *http.Request, in *Item) (*Item, error) { + in.ID = r.PathValue("id") + return in, nil }) resp := doRequest(t, api, http.MethodPatch, "/items/99", `{"name":"patched"}`) @@ -372,7 +373,7 @@ func TestPatchHandler(t *testing.T) { func TestDeleteHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Delete(api, "/items/{id}", func(r *http.Request) (*Empty, error) { + shiftapi.Delete(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -386,7 +387,7 @@ func TestDeleteHandler(t *testing.T) { func TestHeadHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Head(api, "/ping", func(r *http.Request) (*Empty, error) { + shiftapi.Head(api, "/ping", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -400,7 +401,7 @@ func TestHeadHandler(t *testing.T) { func TestOptionsHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Options(api, "/items", func(r *http.Request) (*Empty, error) { + shiftapi.Options(api, "/items", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -414,7 +415,7 @@ func TestOptionsHandler(t *testing.T) { func TestTraceHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Trace(api, "/debug", func(r *http.Request) (*Empty, error) { + shiftapi.Trace(api, "/debug", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -428,7 +429,7 @@ func TestTraceHandler(t *testing.T) { func TestConnectHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Connect(api, "/tunnel", func(r *http.Request) (*Empty, error) { + shiftapi.Connect(api, "/tunnel", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -442,7 +443,7 @@ func TestConnectHandler(t *testing.T) { func TestAPIErrorReturnsCorrectStatusCode(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/fail", func(r *http.Request) (*Empty, error) { + shiftapi.Get(api, "/fail", func(r *http.Request, _ struct{}) (*Empty, error) { return nil, shiftapi.Error(http.StatusNotFound, "not found") }) @@ -458,7 +459,7 @@ func TestAPIErrorReturnsCorrectStatusCode(t *testing.T) { func TestAPIErrorReturnsJSON(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/fail", func(r *http.Request, body *Person) (*Greeting, error) { + shiftapi.Post(api, "/fail", func(r *http.Request, in *Person) (*Greeting, error) { return nil, shiftapi.Error(http.StatusUnprocessableEntity, "invalid data") }) @@ -473,7 +474,7 @@ func TestAPIErrorReturnsJSON(t *testing.T) { func TestGenericErrorReturns500(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/boom", func(r *http.Request) (*Empty, error) { + shiftapi.Get(api, "/boom", func(r *http.Request, _ struct{}) (*Empty, error) { return nil, errors.New("something broke") }) @@ -506,9 +507,9 @@ func TestAPIErrorMessage(t *testing.T) { func TestWithStatusCustomCode(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/items", func(r *http.Request, body *Item) (*Item, error) { - body.ID = "new-id" - return body, nil + shiftapi.Post(api, "/items", func(r *http.Request, in *Item) (*Item, error) { + in.ID = "new-id" + return in, nil }, shiftapi.WithStatus(http.StatusCreated)) resp := doRequest(t, api, http.MethodPost, "/items", `{"name":"widget"}`) @@ -519,7 +520,7 @@ func TestWithStatusCustomCode(t *testing.T) { func TestWithStatusOnGetHandler(t *testing.T) { api := newTestAPI(t) - shiftapi.Delete(api, "/items/{id}", func(r *http.Request) (*Empty, error) { + shiftapi.Delete(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }, shiftapi.WithStatus(http.StatusNoContent)) @@ -533,8 +534,8 @@ func TestWithStatusOnGetHandler(t *testing.T) { func TestWithRouteInfoInSpec(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }, shiftapi.WithRouteInfo(shiftapi.RouteInfo{ Summary: "Greet someone", Description: "Greets a person by name", @@ -564,7 +565,7 @@ func TestWithRouteInfoInSpec(t *testing.T) { func TestSpecHasPath(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -576,7 +577,7 @@ func TestSpecHasPath(t *testing.T) { func TestSpecGetHasNoRequestBody(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -592,8 +593,8 @@ func TestSpecGetHasNoRequestBody(t *testing.T) { func TestSpecPostHasRequestBody(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }) spec := api.Spec() @@ -608,8 +609,8 @@ func TestSpecPostHasRequestBody(t *testing.T) { func TestSpecRequestBodyIsRequired(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }) spec := api.Spec() @@ -623,9 +624,144 @@ func TestSpecRequestBodyIsRequired(t *testing.T) { } } +// --- Empty body behavior for body-carrying methods --- + +func TestPostNoInputRequiresBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/trigger", func(r *http.Request, _ struct{}) (*Status, error) { + return &Status{OK: true}, nil + }) + + // Empty body should be rejected + resp := doRequest(t, api, http.MethodPost, "/trigger", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400 for POST without body, got %d", resp.StatusCode) + } + + // Empty JSON object should be accepted + resp2 := doRequest(t, api, http.MethodPost, "/trigger", `{}`) + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for POST with {}, got %d", resp2.StatusCode) + } +} + +func TestPutNoInputRequiresBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Put(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { + return &Empty{}, nil + }) + + resp := doRequest(t, api, http.MethodPut, "/items/1", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400 for PUT without body, got %d", resp.StatusCode) + } + + resp2 := doRequest(t, api, http.MethodPut, "/items/1", `{}`) + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for PUT with {}, got %d", resp2.StatusCode) + } +} + +func TestPatchNoInputRequiresBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Patch(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { + return &Empty{}, nil + }) + + resp := doRequest(t, api, http.MethodPatch, "/items/1", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400 for PATCH without body, got %d", resp.StatusCode) + } + + resp2 := doRequest(t, api, http.MethodPatch, "/items/1", `{}`) + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for PATCH with {}, got %d", resp2.StatusCode) + } +} + +func TestGetNoInputDoesNotRequireBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { + return &Status{OK: true}, nil + }) + + // GET without body should succeed + resp := doRequest(t, api, http.MethodGet, "/health", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for GET without body, got %d", resp.StatusCode) + } +} + +func TestDeleteNoInputDoesNotRequireBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Delete(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { + return &Empty{}, nil + }) + + // DELETE without body should succeed + resp := doRequest(t, api, http.MethodDelete, "/items/1", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for DELETE without body, got %d", resp.StatusCode) + } +} + +// --- Spec: empty body on body-carrying methods --- + +func TestSpecPostNoInputHasEmptyRequestBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/trigger", func(r *http.Request, _ struct{}) (*Status, error) { + return &Status{OK: true}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/trigger").Post + if op.RequestBody == nil { + t.Fatal("POST with no input should still have a request body in the spec") + } + if !op.RequestBody.Value.Required { + t.Error("request body should be required") + } + content := op.RequestBody.Value.Content["application/json"] + if content == nil { + t.Fatal("expected application/json content") + } + if !content.Schema.Value.Type.Is("object") { + t.Errorf("expected empty object schema, got %v", content.Schema.Value.Type) + } + if len(content.Schema.Value.Properties) != 0 { + t.Errorf("expected 0 properties, got %d", len(content.Schema.Value.Properties)) + } +} + +func TestSpecPutNoInputHasEmptyRequestBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Put(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items/{id}").Put + if op.RequestBody == nil { + t.Fatal("PUT with no input should still have a request body in the spec") + } +} + +func TestSpecGetNoInputHasNoRequestBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { + return &Status{OK: true}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/health").Get + if op.RequestBody != nil { + t.Error("GET with no input should not have a request body in the spec") + } +} + func TestSpecDeleteHasNoRequestBody(t *testing.T) { api := newTestAPI(t) - shiftapi.Delete(api, "/items/{id}", func(r *http.Request) (*Empty, error) { + shiftapi.Delete(api, "/items/{id}", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -641,7 +777,7 @@ func TestSpecDeleteHasNoRequestBody(t *testing.T) { func TestSpecHasResponseSchema(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -658,7 +794,7 @@ func TestSpecHasResponseSchema(t *testing.T) { func TestSpecResponseDescriptionUsesStatusText(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -671,8 +807,8 @@ func TestSpecResponseDescriptionUsesStatusText(t *testing.T) { func TestSpecWithStatusUsesCorrectCodeInSpec(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/items", func(r *http.Request, body *Item) (*Item, error) { - return body, nil + shiftapi.Post(api, "/items", func(r *http.Request, in *Item) (*Item, error) { + return in, nil }, shiftapi.WithStatus(http.StatusCreated)) spec := api.Spec() @@ -687,8 +823,8 @@ func TestSpecWithStatusUsesCorrectCodeInSpec(t *testing.T) { func TestSpecComponentSchemasPopulated(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/greet", func(r *http.Request, body *Person) (*Greeting, error) { - return &Greeting{Hello: body.Name}, nil + shiftapi.Post(api, "/greet", func(r *http.Request, in *Person) (*Greeting, error) { + return &Greeting{Hello: in.Name}, nil }) spec := api.Spec() @@ -699,11 +835,11 @@ func TestSpecComponentSchemasPopulated(t *testing.T) { func TestSpecMultipleMethodsOnSamePath(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/items", func(r *http.Request) (*[]Item, error) { + shiftapi.Get(api, "/items", func(r *http.Request, _ struct{}) (*[]Item, error) { return &[]Item{}, nil }) - shiftapi.Post(api, "/items", func(r *http.Request, body *Item) (*Item, error) { - return body, nil + shiftapi.Post(api, "/items", func(r *http.Request, in *Item) (*Item, error) { + return in, nil }) spec := api.Spec() @@ -730,7 +866,7 @@ func TestSpecOpenAPIVersion(t *testing.T) { func TestSpecPathParametersDocumented(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/users/{id}", func(r *http.Request) (*Item, error) { + shiftapi.Get(api, "/users/{id}", func(r *http.Request, _ struct{}) (*Item, error) { return &Item{ID: r.PathValue("id")}, nil }) @@ -753,7 +889,7 @@ func TestSpecPathParametersDocumented(t *testing.T) { func TestSpecMultiplePathParameters(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/orgs/{orgId}/users/{userId}", func(r *http.Request) (*Item, error) { + shiftapi.Get(api, "/orgs/{orgId}/users/{userId}", func(r *http.Request, _ struct{}) (*Item, error) { return &Item{}, nil }) @@ -772,7 +908,7 @@ func TestSpecMultiplePathParameters(t *testing.T) { func TestSpecNoPathParametersWhenNoneInPath(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -803,19 +939,19 @@ func TestSpecOperationID(t *testing.T) { api := newTestAPI(t) switch tc.method { case "GET": - shiftapi.Get(api, tc.path, func(r *http.Request) (*Empty, error) { + shiftapi.Get(api, tc.path, func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) case "POST": - shiftapi.Post(api, tc.path, func(r *http.Request, body *Empty) (*Empty, error) { + shiftapi.Post(api, tc.path, func(r *http.Request, in *Empty) (*Empty, error) { return &Empty{}, nil }) case "PUT": - shiftapi.Put(api, tc.path, func(r *http.Request, body *Empty) (*Empty, error) { + shiftapi.Put(api, tc.path, func(r *http.Request, in *Empty) (*Empty, error) { return &Empty{}, nil }) case "DELETE": - shiftapi.Delete(api, tc.path, func(r *http.Request) (*Empty, error) { + shiftapi.Delete(api, tc.path, func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) } @@ -844,7 +980,7 @@ func TestSpecOperationID(t *testing.T) { func TestSpecHasDefaultErrorResponse(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -869,8 +1005,8 @@ func TestSpecHasDefaultErrorResponse(t *testing.T) { func TestSpecDefaultErrorResponseOnPost(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/items", func(r *http.Request, body *Item) (*Item, error) { - return body, nil + shiftapi.Post(api, "/items", func(r *http.Request, in *Item) (*Item, error) { + return in, nil }) spec := api.Spec() @@ -888,7 +1024,7 @@ func TestAPIImplementsHTTPHandler(t *testing.T) { func TestHTTPTestServerCompatibility(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/ping", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/ping", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -914,7 +1050,7 @@ func TestHTTPTestServerCompatibility(t *testing.T) { func TestMiddlewareCompatibility(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -943,7 +1079,7 @@ func addHeaderMiddleware(key, value string) func(http.Handler) http.Handler { func TestMountUnderPrefix(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/health", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/health", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -964,7 +1100,7 @@ func TestMountUnderPrefix(t *testing.T) { func TestHandlerAccessesHeaders(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/echo-header", func(r *http.Request) (*map[string]string, error) { + shiftapi.Get(api, "/echo-header", func(r *http.Request, _ struct{}) (*map[string]string, error) { return &map[string]string{ "value": r.Header.Get("X-Test"), }, nil @@ -987,7 +1123,7 @@ func TestHandlerAccessesHeaders(t *testing.T) { func TestHandlerAccessesQueryParams(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/search", func(r *http.Request) (*map[string]string, error) { + shiftapi.Get(api, "/search", func(r *http.Request, _ struct{}) (*map[string]string, error) { return &map[string]string{ "q": r.URL.Query().Get("q"), }, nil @@ -1005,7 +1141,7 @@ func TestHandlerAccessesQueryParams(t *testing.T) { func TestHandlerAccessesContext(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/ctx", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/ctx", func(r *http.Request, _ struct{}) (*Status, error) { if r.Context() == nil { return nil, errors.New("context is nil") } @@ -1022,7 +1158,7 @@ func TestHandlerAccessesContext(t *testing.T) { func TestSuccessResponseHasJSONContentType(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/test", func(r *http.Request) (*Status, error) { + shiftapi.Get(api, "/test", func(r *http.Request, _ struct{}) (*Status, error) { return &Status{OK: true}, nil }) @@ -1034,7 +1170,7 @@ func TestSuccessResponseHasJSONContentType(t *testing.T) { func TestErrorResponseFromAPIErrorHasJSONContentType(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/fail", func(r *http.Request) (*Empty, error) { + shiftapi.Get(api, "/fail", func(r *http.Request, _ struct{}) (*Empty, error) { return nil, shiftapi.Error(http.StatusBadRequest, "bad") }) @@ -1049,13 +1185,13 @@ func TestErrorResponseFromAPIErrorHasJSONContentType(t *testing.T) { func TestMultipleRoutes(t *testing.T) { api := newTestAPI(t) - shiftapi.Get(api, "/a", func(r *http.Request) (*map[string]string, error) { + shiftapi.Get(api, "/a", func(r *http.Request, _ struct{}) (*map[string]string, error) { return &map[string]string{"route": "a"}, nil }) - shiftapi.Get(api, "/b", func(r *http.Request) (*map[string]string, error) { + shiftapi.Get(api, "/b", func(r *http.Request, _ struct{}) (*map[string]string, error) { return &map[string]string{"route": "b"}, nil }) - shiftapi.Post(api, "/c", func(r *http.Request, body *Empty) (*map[string]string, error) { + shiftapi.Post(api, "/c", func(r *http.Request, in *Empty) (*map[string]string, error) { return &map[string]string{"route": "c"}, nil }) @@ -1095,7 +1231,7 @@ func TestSpecReturnsLiveObject(t *testing.T) { api := newTestAPI(t) before := len(api.Spec().Paths.InMatchingOrder()) - shiftapi.Get(api, "/new-route", func(r *http.Request) (*Empty, error) { + shiftapi.Get(api, "/new-route", func(r *http.Request, _ struct{}) (*Empty, error) { return &Empty{}, nil }) @@ -1129,8 +1265,8 @@ type NoValidateBody struct { func TestValidationRequiredFieldMissing(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/person", `{"email":"test@example.com"}`) @@ -1162,8 +1298,8 @@ func TestValidationRequiredFieldMissing(t *testing.T) { func TestValidationEmailInvalid(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/person", `{"name":"alice","email":"not-an-email"}`) @@ -1192,8 +1328,8 @@ func TestValidationEmailInvalid(t *testing.T) { func TestValidationMinMaxViolated(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/minmax", func(r *http.Request, body *MinMaxBody) (*MinMaxBody, error) { - return body, nil + shiftapi.Post(api, "/minmax", func(r *http.Request, in *MinMaxBody) (*MinMaxBody, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/minmax", `{"age":0,"name":"a"}`) @@ -1204,8 +1340,8 @@ func TestValidationMinMaxViolated(t *testing.T) { func TestValidationValidPayloadPassesThrough(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/person", `{"name":"alice","email":"alice@example.com"}`) @@ -1220,8 +1356,8 @@ func TestValidationValidPayloadPassesThrough(t *testing.T) { func TestValidationNoTagsPassThrough(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/noval", func(r *http.Request, body *NoValidateBody) (*NoValidateBody, error) { - return body, nil + shiftapi.Post(api, "/noval", func(r *http.Request, in *NoValidateBody) (*NoValidateBody, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/noval", `{"foo":"bar"}`) @@ -1233,8 +1369,8 @@ func TestValidationNoTagsPassThrough(t *testing.T) { func TestWithValidatorCustomInstance(t *testing.T) { v := validator.New() api := shiftapi.New(shiftapi.WithValidator(v)) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) // Missing required fields should still fail @@ -1262,8 +1398,8 @@ func TestValidationErrorSatisfiesErrorsAs(t *testing.T) { func TestSpecRequiredFieldInParentSchema(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) spec := api.Spec() @@ -1273,18 +1409,18 @@ func TestSpecRequiredFieldInParentSchema(t *testing.T) { t.Fatal("expected ValidatedPerson in component schemas") } schema := schemaRef.Value - if !contains(schema.Required, "name") { + if !slices.Contains(schema.Required, "name") { t.Errorf("expected 'name' in required, got %v", schema.Required) } - if !contains(schema.Required, "email") { + if !slices.Contains(schema.Required, "email") { t.Errorf("expected 'email' in required, got %v", schema.Required) } } func TestSpecEmailFormatSet(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person", func(r *http.Request, body *ValidatedPerson) (*ValidatedPerson, error) { - return body, nil + shiftapi.Post(api, "/person", func(r *http.Request, in *ValidatedPerson) (*ValidatedPerson, error) { + return in, nil }) spec := api.Spec() @@ -1300,8 +1436,8 @@ func TestSpecEmailFormatSet(t *testing.T) { func TestSpecMinMaxOnNumber(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/minmax", func(r *http.Request, body *MinMaxBody) (*MinMaxBody, error) { - return body, nil + shiftapi.Post(api, "/minmax", func(r *http.Request, in *MinMaxBody) (*MinMaxBody, error) { + return in, nil }) spec := api.Spec() @@ -1320,8 +1456,8 @@ func TestSpecMinMaxOnNumber(t *testing.T) { func TestSpecMinMaxOnString(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/minmax", func(r *http.Request, body *MinMaxBody) (*MinMaxBody, error) { - return body, nil + shiftapi.Post(api, "/minmax", func(r *http.Request, in *MinMaxBody) (*MinMaxBody, error) { + return in, nil }) spec := api.Spec() @@ -1340,8 +1476,8 @@ func TestSpecMinMaxOnString(t *testing.T) { func TestSpecEnumOnOneOf(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/oneof", func(r *http.Request, body *OneOfBody) (*OneOfBody, error) { - return body, nil + shiftapi.Post(api, "/oneof", func(r *http.Request, in *OneOfBody) (*OneOfBody, error) { + return in, nil }) spec := api.Spec() @@ -1375,8 +1511,8 @@ type PersonWithAddress struct { func TestValidationNestedStructValid(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person-addr", func(r *http.Request, body *PersonWithAddress) (*PersonWithAddress, error) { - return body, nil + shiftapi.Post(api, "/person-addr", func(r *http.Request, in *PersonWithAddress) (*PersonWithAddress, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/person-addr", `{"name":"alice","address":{"street":"123 Main St","city":"Springfield"}}`) @@ -1387,8 +1523,8 @@ func TestValidationNestedStructValid(t *testing.T) { func TestValidationNestedStructMissingFields(t *testing.T) { api := newTestAPI(t) - shiftapi.Post(api, "/person-addr", func(r *http.Request, body *PersonWithAddress) (*PersonWithAddress, error) { - return body, nil + shiftapi.Post(api, "/person-addr", func(r *http.Request, in *PersonWithAddress) (*PersonWithAddress, error) { + return in, nil }) resp := doRequest(t, api, http.MethodPost, "/person-addr", `{"name":"alice","address":{}}`) @@ -1397,11 +1533,963 @@ func TestValidationNestedStructMissingFields(t *testing.T) { } } -func contains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true +// --- Query parameter test types --- + +type SearchQuery struct { + Q string `query:"q" validate:"required"` + Page int `query:"page" validate:"min=1"` + Limit int `query:"limit" validate:"min=1,max=100"` +} + +type SearchResult struct { + Query string `json:"query"` + Page int `json:"page"` + Limit int `json:"limit"` +} + +type TagQuery struct { + Tags []string `query:"tag"` +} + +type TagResult struct { + Tags []string `json:"tags"` +} + +type OptionalQuery struct { + Name string `query:"name"` + Debug *bool `query:"debug"` + Limit *int `query:"limit"` +} + +type OptionalResult struct { + Name string `json:"name"` + HasDebug bool `json:"has_debug"` + Debug bool `json:"debug"` + HasLimit bool `json:"has_limit"` + Limit int `json:"limit"` +} + +type FilterQuery struct { + Status string `query:"status" validate:"oneof=active inactive pending"` +} + +// --- Query parameter runtime tests --- + +func TestGetWithQueryBasic(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{Query: in.Q, Page: in.Page, Limit: in.Limit}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/search?q=hello&page=2&limit=10", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[SearchResult](t, resp) + if result.Query != "hello" { + t.Errorf("expected Query=hello, got %q", result.Query) + } + if result.Page != 2 { + t.Errorf("expected Page=2, got %d", result.Page) + } + if result.Limit != 10 { + t.Errorf("expected Limit=10, got %d", result.Limit) + } +} + +func TestGetWithQueryMissingRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + // Missing required "q" param + resp := doRequest(t, api, http.MethodGet, "/search?page=1&limit=10", "") + if resp.StatusCode != http.StatusUnprocessableEntity { + t.Fatalf("expected 422, got %d", resp.StatusCode) + } +} + +func TestGetWithQueryInvalidType(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + // "page" should be an int, not "abc" + resp := doRequest(t, api, http.MethodGet, "/search?q=test&page=abc&limit=10", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithQuerySliceParams(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/tags", func(r *http.Request, in TagQuery) (*TagResult, error) { + return &TagResult{Tags: in.Tags}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/tags?tag=a&tag=b&tag=c", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[TagResult](t, resp) + if len(result.Tags) != 3 { + t.Fatalf("expected 3 tags, got %d", len(result.Tags)) + } + expected := []string{"a", "b", "c"} + for i, tag := range result.Tags { + if tag != expected[i] { + t.Errorf("expected tag[%d]=%q, got %q", i, expected[i], tag) + } + } +} + +func TestGetWithQueryOptionalPointer(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalQuery) (*OptionalResult, error) { + result := &OptionalResult{Name: in.Name} + if in.Debug != nil { + result.HasDebug = true + result.Debug = *in.Debug + } + if in.Limit != nil { + result.HasLimit = true + result.Limit = *in.Limit + } + return result, nil + }) + + // With optional params + resp := doRequest(t, api, http.MethodGet, "/optional?name=test&debug=true&limit=50", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[OptionalResult](t, resp) + if !result.HasDebug || !result.Debug { + t.Error("expected debug=true") + } + if !result.HasLimit || result.Limit != 50 { + t.Error("expected limit=50") + } + + // Without optional params + resp2 := doRequest(t, api, http.MethodGet, "/optional?name=test", "") + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp2.StatusCode) + } + result2 := decodeJSON[OptionalResult](t, resp2) + if result2.HasDebug { + t.Error("expected debug to be absent") + } + if result2.HasLimit { + t.Error("expected limit to be absent") + } +} + +func TestPostWithQueryAndBody(t *testing.T) { + api := newTestAPI(t) + + type CreateInput struct { + DryRun bool `query:"dry_run"` + Name string `json:"name"` + } + type CreateResult struct { + Name string `json:"name"` + DryRun bool `json:"dry_run"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in CreateInput) (*CreateResult, error) { + return &CreateResult{Name: in.Name, DryRun: in.DryRun}, nil + }) + + resp := doRequest(t, api, http.MethodPost, "/items?dry_run=true", `{"name":"widget"}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[CreateResult](t, resp) + if result.Name != "widget" { + t.Errorf("expected Name=widget, got %q", result.Name) + } + if !result.DryRun { + t.Error("expected DryRun=true") + } +} + +// --- Query/JSON tag interop tests --- + +func TestQueryFieldInBodyIsIgnored(t *testing.T) { + api := newTestAPI(t) + + type Input struct { + DryRun bool `query:"dry_run"` + Name string `json:"name"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in Input) (*map[string]any, error) { + return &map[string]any{"name": in.Name, "dry_run": in.DryRun}, nil + }) + + // Use the Go field name "DryRun" which json.Decode would match + // case-insensitively — resetQueryFields must clear it. + resp := doRequest(t, api, http.MethodPost, "/items", `{"name":"widget","DryRun":true}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]any](t, resp) + if result["name"] != "widget" { + t.Errorf("expected name=widget, got %v", result["name"]) + } + if result["dry_run"] != false { + t.Errorf("expected dry_run=false (query field must not be set from body), got %v", result["dry_run"]) + } +} + +func TestBodyFieldInQueryIsIgnored(t *testing.T) { + api := newTestAPI(t) + + type Input struct { + DryRun bool `query:"dry_run"` + Name string `json:"name"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in Input) (*map[string]any, error) { + return &map[string]any{"name": in.Name, "dry_run": in.DryRun}, nil + }) + + // Send name in query but NOT in body — should remain empty + resp := doRequest(t, api, http.MethodPost, "/items?name=sneaky&dry_run=true", `{"name":"widget"}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]any](t, resp) + if result["name"] != "widget" { + t.Errorf("expected name from body (widget), got %v", result["name"]) + } + if result["dry_run"] != true { + t.Errorf("expected dry_run=true from query, got %v", result["dry_run"]) + } +} + +func TestFieldWithBothJsonAndQueryTagsUsesQuery(t *testing.T) { + api := newTestAPI(t) + + type Input struct { + Mode string `query:"mode" json:"mode"` + Name string `json:"name"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in Input) (*map[string]string, error) { + return &map[string]string{"mode": in.Mode, "name": in.Name}, nil + }) + + // Send conflicting values: "body_mode" in body, "query_mode" in query + resp := doRequest(t, api, http.MethodPost, "/items?mode=query_mode", `{"name":"widget","mode":"body_mode"}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]string](t, resp) + // Query parsing runs after body decode, so query value should win + if result["mode"] != "query_mode" { + t.Errorf("expected mode=query_mode (query overrides body), got %q", result["mode"]) + } + if result["name"] != "widget" { + t.Errorf("expected name=widget, got %q", result["name"]) + } +} + +func TestSpecMixedStructBodyExcludesQueryFields(t *testing.T) { + api := newTestAPI(t) + + type Input struct { + DryRun bool `query:"dry_run"` + Name string `json:"name"` + ID string `json:"id"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in Input) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items").Post + + // Query params should include dry_run only + queryParams := 0 + for _, p := range op.Parameters { + if p.Value.In == "query" { + queryParams++ + if p.Value.Name != "dry_run" { + t.Errorf("unexpected query param %q", p.Value.Name) + } + } + } + if queryParams != 1 { + t.Errorf("expected 1 query parameter, got %d", queryParams) + } + + // Body schema should include name and id but NOT dry_run + if op.RequestBody == nil { + t.Fatal("expected request body") + } + bodyRef := op.RequestBody.Value.Content["application/json"].Schema.Ref + schemaName := bodyRef[len("#/components/schemas/"):] + bodySchema := spec.Components.Schemas[schemaName].Value + if bodySchema.Properties["name"] == nil { + t.Error("expected 'name' in body schema") + } + if bodySchema.Properties["id"] == nil { + t.Error("expected 'id' in body schema") + } + if bodySchema.Properties["dry_run"] != nil { + t.Error("'dry_run' should NOT be in body schema (it's a query param)") + } +} + +func TestSpecMixedStructQueryExcludesBodyFields(t *testing.T) { + api := newTestAPI(t) + + type Input struct { + Filter string `query:"filter"` + Sort string `query:"sort"` + Name string `json:"name"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in Input) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items").Post + + // Query params should only be filter and sort, not name + paramNames := map[string]bool{} + for _, p := range op.Parameters { + if p.Value.In == "query" { + paramNames[p.Value.Name] = true } } - return false + if !paramNames["filter"] { + t.Error("expected 'filter' query param") + } + if !paramNames["sort"] { + t.Error("expected 'sort' query param") + } + if paramNames["name"] { + t.Error("'name' should NOT be a query param (it's a body field)") + } +} + +func TestGetWithQueryAndPathParams(t *testing.T) { + api := newTestAPI(t) + + type ItemQuery struct { + Fields string `query:"fields"` + } + + shiftapi.Get(api, "/items/{id}", func(r *http.Request, in ItemQuery) (*map[string]string, error) { + return &map[string]string{ + "id": r.PathValue("id"), + "fields": in.Fields, + }, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/items/abc123?fields=name,price", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]string](t, resp) + if result["id"] != "abc123" { + t.Errorf("expected id=abc123, got %q", result["id"]) + } + if result["fields"] != "name,price" { + t.Errorf("expected fields=name,price, got %q", result["fields"]) + } +} + +func TestDeleteWithQuery(t *testing.T) { + api := newTestAPI(t) + + type DeleteQuery struct { + Force bool `query:"force"` + } + + shiftapi.Delete(api, "/items/{id}", func(r *http.Request, in DeleteQuery) (*map[string]any, error) { + return &map[string]any{ + "id": r.PathValue("id"), + "force": in.Force, + }, nil + }) + + resp := doRequest(t, api, http.MethodDelete, "/items/42?force=true", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } +} + +func TestGetWithQueryValidationConstraint(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/filter", func(r *http.Request, in FilterQuery) (*map[string]string, error) { + return &map[string]string{"status": in.Status}, nil + }) + + // Valid value + resp := doRequest(t, api, http.MethodGet, "/filter?status=active", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + // Invalid value -> 422 + resp2 := doRequest(t, api, http.MethodGet, "/filter?status=unknown", "") + if resp2.StatusCode != http.StatusUnprocessableEntity { + t.Fatalf("expected 422, got %d", resp2.StatusCode) + } +} + +// --- Query parameter: scalar type parsing --- + +func TestGetWithQueryBoolScalar(t *testing.T) { + api := newTestAPI(t) + + type BoolQuery struct { + Verbose bool `query:"verbose"` + } + + shiftapi.Get(api, "/logs", func(r *http.Request, in BoolQuery) (*map[string]bool, error) { + return &map[string]bool{"verbose": in.Verbose}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/logs?verbose=true", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]bool](t, resp) + if !result["verbose"] { + t.Error("expected verbose=true") + } + + // false value + resp2 := doRequest(t, api, http.MethodGet, "/logs?verbose=false", "") + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp2.StatusCode) + } + result2 := decodeJSON[map[string]bool](t, resp2) + if result2["verbose"] { + t.Error("expected verbose=false") + } +} + +func TestGetWithQueryUint(t *testing.T) { + api := newTestAPI(t) + + type PageQuery struct { + Offset uint `query:"offset"` + Limit uint64 `query:"limit"` + } + + shiftapi.Get(api, "/pages", func(r *http.Request, in PageQuery) (*map[string]uint64, error) { + return &map[string]uint64{"offset": uint64(in.Offset), "limit": in.Limit}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/pages?offset=10&limit=100", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]float64](t, resp) + if result["offset"] != 10 { + t.Errorf("expected offset=10, got %v", result["offset"]) + } + if result["limit"] != 100 { + t.Errorf("expected limit=100, got %v", result["limit"]) + } +} + +func TestGetWithQueryFloat(t *testing.T) { + api := newTestAPI(t) + + type CoordQuery struct { + Lat float64 `query:"lat"` + Lng float32 `query:"lng"` + } + + shiftapi.Get(api, "/nearby", func(r *http.Request, in CoordQuery) (*map[string]float64, error) { + return &map[string]float64{"lat": in.Lat, "lng": float64(in.Lng)}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/nearby?lat=40.7128&lng=-74.006", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]float64](t, resp) + if result["lat"] != 40.7128 { + t.Errorf("expected lat=40.7128, got %v", result["lat"]) + } +} + +// --- Query parameter: parse errors --- + +func TestGetWithQueryInvalidBool(t *testing.T) { + api := newTestAPI(t) + + type BoolQuery struct { + Debug bool `query:"debug"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in BoolQuery) (*Empty, error) { + return &Empty{}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/test?debug=notabool", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithQueryInvalidUint(t *testing.T) { + api := newTestAPI(t) + + type UintQuery struct { + Count uint `query:"count"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in UintQuery) (*Empty, error) { + return &Empty{}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/test?count=-1", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithQueryInvalidFloat(t *testing.T) { + api := newTestAPI(t) + + type FloatQuery struct { + Score float64 `query:"score"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in FloatQuery) (*Empty, error) { + return &Empty{}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/test?score=abc", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +// --- Query parameter: skip and zero-value behavior --- + +func TestGetWithQuerySkipTag(t *testing.T) { + api := newTestAPI(t) + + type SkipQuery struct { + Name string `query:"name"` + Secret string `json:"-"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in SkipQuery) (*map[string]string, error) { + return &map[string]string{"name": in.Name, "secret": in.Secret}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/test?name=alice&secret=hidden", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]string](t, resp) + if result["name"] != "alice" { + t.Errorf("expected name=alice, got %q", result["name"]) + } + if result["secret"] != "" { + t.Errorf("expected secret to be empty (skipped), got %q", result["secret"]) + } +} + +func TestGetWithQueryAbsentParamsGetZeroValues(t *testing.T) { + api := newTestAPI(t) + + type MixedQuery struct { + Name string `query:"name"` + Count int `query:"count"` + Flag bool `query:"flag"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in MixedQuery) (*map[string]any, error) { + return &map[string]any{ + "name": in.Name, + "count": in.Count, + "flag": in.Flag, + }, nil + }) + + // No query params at all — everything should be zero-valued + resp := doRequest(t, api, http.MethodGet, "/test", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]any](t, resp) + if result["name"] != "" { + t.Errorf("expected name=\"\", got %v", result["name"]) + } + if result["count"] != float64(0) { + t.Errorf("expected count=0, got %v", result["count"]) + } + if result["flag"] != false { + t.Errorf("expected flag=false, got %v", result["flag"]) + } +} + +// --- Query parameter: spec types for bool/float/uint --- + +func TestSpecQueryParamBoolType(t *testing.T) { + api := newTestAPI(t) + + type BoolQuery struct { + Debug bool `query:"debug"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in BoolQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/test").Get + if len(op.Parameters) != 1 { + t.Fatalf("expected 1 parameter, got %d", len(op.Parameters)) + } + if !op.Parameters[0].Value.Schema.Value.Type.Is("boolean") { + t.Errorf("expected boolean type, got %v", op.Parameters[0].Value.Schema.Value.Type) + } +} + +func TestSpecQueryParamFloatType(t *testing.T) { + api := newTestAPI(t) + + type FloatQuery struct { + Score float64 `query:"score"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in FloatQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/test").Get + if len(op.Parameters) != 1 { + t.Fatalf("expected 1 parameter, got %d", len(op.Parameters)) + } + if !op.Parameters[0].Value.Schema.Value.Type.Is("number") { + t.Errorf("expected number type, got %v", op.Parameters[0].Value.Schema.Value.Type) + } +} + +func TestSpecQueryParamUintType(t *testing.T) { + api := newTestAPI(t) + + type UintQuery struct { + Count uint `query:"count"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in UintQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/test").Get + if len(op.Parameters) != 1 { + t.Fatalf("expected 1 parameter, got %d", len(op.Parameters)) + } + if !op.Parameters[0].Value.Schema.Value.Type.Is("integer") { + t.Errorf("expected integer type, got %v", op.Parameters[0].Value.Schema.Value.Type) + } +} + +func TestSpecQuerySkipTagNotDocumented(t *testing.T) { + api := newTestAPI(t) + + type SkipQuery struct { + Name string `query:"name"` + Secret string `json:"-"` + } + + shiftapi.Get(api, "/test", func(r *http.Request, in SkipQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/test").Get + if len(op.Parameters) != 1 { + t.Fatalf("expected 1 parameter (secret should be skipped), got %d", len(op.Parameters)) + } + if op.Parameters[0].Value.Name != "name" { + t.Errorf("expected parameter 'name', got %q", op.Parameters[0].Value.Name) + } +} + +func TestSpecQueryParamOptionalPointerNotRequired(t *testing.T) { + api := newTestAPI(t) + + shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/optional").Get + + for _, p := range op.Parameters { + if p.Value.Required { + t.Errorf("expected parameter %q to not be required (pointer type)", p.Value.Name) + } + } +} + +// --- Query parameter spec tests --- + +func TestSpecQueryParamsDocumented(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + if op == nil { + t.Fatal("expected GET operation on /search") + } + + // Should have 3 query params: q, page, limit + queryParams := 0 + for _, p := range op.Parameters { + if p.Value.In == "query" { + queryParams++ + } + } + if queryParams != 3 { + t.Fatalf("expected 3 query parameters, got %d", queryParams) + } +} + +func TestSpecQueryParamTypes(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + paramByName := make(map[string]*openapi3.Parameter) + for _, p := range op.Parameters { + paramByName[p.Value.Name] = p.Value + } + + // q is a string + if q, ok := paramByName["q"]; !ok { + t.Fatal("expected 'q' query parameter") + } else if !q.Schema.Value.Type.Is("string") { + t.Errorf("expected q type 'string', got %v", q.Schema.Value.Type) + } + + // page is an integer + if page, ok := paramByName["page"]; !ok { + t.Fatal("expected 'page' query parameter") + } else if !page.Schema.Value.Type.Is("integer") { + t.Errorf("expected page type 'integer', got %v", page.Schema.Value.Type) + } + + // limit is an integer + if limit, ok := paramByName["limit"]; !ok { + t.Fatal("expected 'limit' query parameter") + } else if !limit.Schema.Value.Type.Is("integer") { + t.Errorf("expected limit type 'integer', got %v", limit.Schema.Value.Type) + } +} + +func TestSpecQueryParamRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + paramByName := make(map[string]*openapi3.Parameter) + for _, p := range op.Parameters { + paramByName[p.Value.Name] = p.Value + } + + // q has validate:"required" so it should be required + if !paramByName["q"].Required { + t.Error("expected 'q' to be required") + } + // page does not have required tag + if paramByName["page"].Required { + t.Error("expected 'page' to not be required") + } +} + +func TestSpecQueryParamValidationConstraints(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + paramByName := make(map[string]*openapi3.Parameter) + for _, p := range op.Parameters { + paramByName[p.Value.Name] = p.Value + } + + // page has min=1 + pageSchema := paramByName["page"].Schema.Value + if pageSchema.Min == nil || *pageSchema.Min != 1 { + t.Errorf("expected page minimum 1, got %v", pageSchema.Min) + } + + // limit has min=1,max=100 + limitSchema := paramByName["limit"].Schema.Value + if limitSchema.Min == nil || *limitSchema.Min != 1 { + t.Errorf("expected limit minimum 1, got %v", limitSchema.Min) + } + if limitSchema.Max == nil || *limitSchema.Max != 100 { + t.Errorf("expected limit maximum 100, got %v", limitSchema.Max) + } +} + +func TestSpecQueryParamEnum(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/filter", func(r *http.Request, in FilterQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/filter").Get + + var statusParam *openapi3.Parameter + for _, p := range op.Parameters { + if p.Value.Name == "status" { + statusParam = p.Value + break + } + } + if statusParam == nil { + t.Fatal("expected 'status' query parameter") + } + if len(statusParam.Schema.Value.Enum) != 3 { + t.Fatalf("expected 3 enum values, got %d", len(statusParam.Schema.Value.Enum)) + } +} + +func TestSpecQueryParamSliceType(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/tags", func(r *http.Request, in TagQuery) (*TagResult, error) { + return &TagResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/tags").Get + + var tagParam *openapi3.Parameter + for _, p := range op.Parameters { + if p.Value.Name == "tag" { + tagParam = p.Value + break + } + } + if tagParam == nil { + t.Fatal("expected 'tag' query parameter") + } + if !tagParam.Schema.Value.Type.Is("array") { + t.Errorf("expected tag type 'array', got %v", tagParam.Schema.Value.Type) + } + if tagParam.Schema.Value.Items == nil || !tagParam.Schema.Value.Items.Value.Type.Is("string") { + t.Error("expected tag items type 'string'") + } +} + +func TestSpecQueryParamsCombinedWithPathParams(t *testing.T) { + api := newTestAPI(t) + + type ItemQuery struct { + Fields string `query:"fields"` + } + + shiftapi.Get(api, "/items/{id}", func(r *http.Request, in ItemQuery) (*Item, error) { + return &Item{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items/{id}").Get + + pathParams := 0 + queryParams := 0 + for _, p := range op.Parameters { + switch p.Value.In { + case "path": + pathParams++ + case "query": + queryParams++ + } + } + if pathParams != 1 { + t.Errorf("expected 1 path parameter, got %d", pathParams) + } + if queryParams != 1 { + t.Errorf("expected 1 query parameter, got %d", queryParams) + } +} + +func TestSpecPostWithQueryHasQueryParamsAndBody(t *testing.T) { + api := newTestAPI(t) + + type CreateInput struct { + DryRun bool `query:"dry_run"` + Name string `json:"name"` + ID string `json:"id"` + } + + shiftapi.Post(api, "/items", func(r *http.Request, in CreateInput) (*Item, error) { + return &Item{ID: in.ID, Name: in.Name}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items").Post + if op == nil { + t.Fatal("expected POST operation on /items") + } + + // Should have query params + queryParams := 0 + for _, p := range op.Parameters { + if p.Value.In == "query" { + queryParams++ + } + } + if queryParams != 1 { + t.Errorf("expected 1 query parameter, got %d", queryParams) + } + + // Should also have a request body + if op.RequestBody == nil { + t.Error("expected request body on POST with query and body") + } +} + +// --- Query-only input should not have request body --- + +func TestSpecQueryOnlyInputHasNoRequestBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + if op.RequestBody != nil { + t.Error("GET with query-only input should not have a request body in the spec") + } } diff --git a/validate.go b/validate.go index b8b91fa..5c301af 100644 --- a/validate.go +++ b/validate.go @@ -38,7 +38,7 @@ func WithValidator(v *validator.Validate) Option { // It dereferences pointers and skips non-struct types. func validateStruct(v *validator.Validate, val any) error { rv := reflect.ValueOf(val) - for rv.Kind() == reflect.Ptr { + for rv.Kind() == reflect.Pointer { if rv.IsNil() { return nil } @@ -110,8 +110,8 @@ func validateSchemaCustomizer(name string, t reflect.Type, tag reflect.StructTag return nil } - rules := strings.Split(validateTag, ",") - for _, rule := range rules { + rules := strings.SplitSeq(validateTag, ",") + for rule := range rules { rule = strings.TrimSpace(rule) key, param, _ := strings.Cut(rule, "=") @@ -219,7 +219,7 @@ func applyLen(t reflect.Type, schema *openapi3.Schema, param string) { } func derefKind(t reflect.Type) reflect.Kind { - for t.Kind() == reflect.Ptr { + for t.Kind() == reflect.Pointer { t = t.Elem() } return t.Kind() @@ -246,7 +246,7 @@ func isSliceKind(k reflect.Kind) bool { // applyRequired walks struct fields and adds JSON names to schema.Required // for fields that have validate:"required". func applyRequired(t reflect.Type, schema *openapi3.Schema) { - for t.Kind() == reflect.Ptr { + for t.Kind() == reflect.Pointer { t = t.Elem() } if t.Kind() != reflect.Struct { @@ -269,7 +269,7 @@ func applyRequired(t reflect.Type, schema *openapi3.Schema) { // hasRule checks whether a comma-separated validate tag contains the given rule name. func hasRule(tag, rule string) bool { - for _, r := range strings.Split(tag, ",") { + for r := range strings.SplitSeq(tag, ",") { key, _, _ := strings.Cut(strings.TrimSpace(r), "=") if key == rule { return true