diff --git a/lib/forms/forms.go b/lib/forms/forms.go new file mode 100644 index 0000000..cd51af8 --- /dev/null +++ b/lib/forms/forms.go @@ -0,0 +1,198 @@ +package forms + +import ( + "errors" + "fmt" + "log" + "net/http" + "reflect" + "strconv" + "strings" + +func Unmarshal(r *http.Request, v any) (err error) { + if u, ok := v.(Unmarshaler); ok { + return u.UnmarshalForm(r) + } + + defer func() { + if r := recover(); r != nil { + err = errors.Join(ErrReflectPanic, fmt.Errorf("Panic recovered: %#v", r)) + } + }() + + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + rt := rv.Type() + + for i := 0; i < rv.NumField(); i++ { + ft := rt.Field(i) + fv := rv.FieldByName(ft.Name) + + log.Print(ft.Name) + + if !fv.CanSet() { + continue + } + + // TODO: Support embedded fields + if ft.Anonymous { + continue + } + + var tv string + if t := ft.Tag.Get("form"); t != "" { + tv = t + } else if t = ft.Tag.Get("query"); t != "" { + tv = t + } else { + tv = ft.Name + } + + tvs := strings.Split(tv, ",") + + name := tvs[0] + required := false + defaultv := "" + + for _, v := range tvs { + if v == "required" { + required = true + } else if strings.HasPrefix(v, "default=") { + defaultv = strings.TrimPrefix(v, "default=") + } + } + + qv := r.FormValue(name) + if qv == "" { + if defaultv != "" { + qv = defaultv + } else if required { + return &ErrMissingRequiredValue{name} + } else { + continue + } + } + + if err := setFieldValue(fv, qv); errors.Is(err, &ErrInvalidValueType{}) { + e, _ := err.(*ErrInvalidValueType) + e.value = name + return e + } else if errors.Is(err, &ErrUnsuportedValueType{}) { + e, _ := err.(*ErrUnsuportedValueType) + e.value = name + return e + } else if err != nil { + return err + } + } + + return nil +} + +func setFieldValue(rv reflect.Value, v string) error { + switch rv.Kind() { + + case reflect.Pointer: + return setFieldValue(rv.Elem(), v) + + case reflect.String: + rv.SetString(v) + + case reflect.Bool: + if cv, err := strconv.ParseBool(v); err != nil { + return &ErrInvalidValueType{"bool", err, ""} + } else { + rv.SetBool(cv) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if cv, err := strconv.Atoi(v); err != nil { + return &ErrInvalidValueType{"int", err, ""} + } else { + rv.SetInt(int64(cv)) + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if cv, err := strconv.Atoi(v); err != nil { + return &ErrInvalidValueType{"uint", err, ""} + } else { + rv.SetUint(uint64(cv)) + } + + case reflect.Float32, reflect.Float64: + if cv, err := strconv.ParseFloat(v, 64); err != nil { + return &ErrInvalidValueType{"float64", err, ""} + } else { + rv.SetFloat(cv) + } + + case reflect.Complex64, reflect.Complex128: + if cv, err := strconv.ParseComplex(v, 128); err != nil { + return &ErrInvalidValueType{"complex128", err, ""} + } else { + rv.SetComplex(cv) + } + + // TODO: Support strucys + // TODO: Support slices + // TODO: Support maps + default: + return &ErrUnsuportedValueType{ + []string{ + "string", + "bool", + "int", "int8", "int16", "int32", "int64", + "uint", "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex64", + }, + "", + } + + } + + return nil +} + +type ErrInvalidValueType struct { + expected string + err error + value string +} + +func (e ErrInvalidValueType) Error() string { + return fmt.Sprintf( + "Value \"%s\" is a invalid type, expected type \"%s\". Got err: %s", + e.value, + e.expected, + e.err.Error(), + ) +} + +type ErrUnsuportedValueType struct { + supported []string + value string +} + +func (e ErrUnsuportedValueType) Error() string { + return fmt.Sprintf( + "Value \"%s\" is a unsupported type, supported types are: \"%s\"", + e.value, + strings.Join(e.supported, ", "), + ) +} + +type ErrMissingRequiredValue struct { + value string +} + +func (e ErrMissingRequiredValue) Error() string { + return fmt.Sprintf("Required value \"%s\" missing from query", e.value) +} + +var ( + ErrParseForm = errors.New("Failed to parse form from body or query parameters") + ErrReflectPanic = errors.New("Reflect panic while trying to parse request") +)