diff --git a/src/xmpp/stream.go b/src/xmpp/stream.go index d1b37b1..d0c80f1 100644 --- a/src/xmpp/stream.go +++ b/src/xmpp/stream.go @@ -18,6 +18,8 @@ const ( type Stream struct { conn net.Conn dec *xml.Decoder + stanzaBuf string + incomingNamespace nsMap } // Create a XML stream connection. A Steam is used by an XMPP instance to @@ -31,7 +33,7 @@ func NewStream(addr string) (*Stream, error) { return nil, err } - stream := &Stream{conn, xml.NewDecoder(conn)} + stream := &Stream{conn: conn, dec: xml.NewDecoder(conn)} if err := stream.send([]byte("")); err != nil { return nil, err @@ -67,7 +69,22 @@ func (stream *Stream) SendStart(start *xml.StartElement) (*xml.StartElement, err } // Read and return start of incoming doc. - return nextStartElement(stream.dec) + rstart, err := nextStartElement(stream.dec) + if err != nil { + return nil, err + } + + // Collect top-level namespaces. + stream.incomingNamespace = make(nsMap) + for _, attr := range rstart.Attr { + if attr.Name.Space == "xmlns" { + stream.incomingNamespace[attr.Value] = attr.Name.Local + } else if attr.Name.Space == "" && attr.Name.Local == "xmlns" { + stream.incomingNamespace[attr.Value] = "" + } + } + + return rstart, nil } // Send a stanza. Used to write a complete, top-level element. @@ -92,13 +109,23 @@ func (stream *Stream) send(b []byte) error { // Bad things are very likely to happen if a call to Next() is successful but // you don't actually decode or skip the element. func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { + start, err := nextStartElement(stream.dec) if err != nil { return nil, err } + + if xml, err := collectElement(stream.dec, start, stream.incomingNamespace); err != nil { + return nil, err + } else { + stream.stanzaBuf = xml + } + log.Println("recv:", stream.stanzaBuf) + if match != nil && start.Name != *match { return nil, fmt.Errorf("Expected %s, got %s", *match, start.Name) } + return start, nil } @@ -125,7 +152,8 @@ func nextStartElement(dec *xml.Decoder) (*xml.StartElement, error) { // Skip reads tokens until it reaches the end element of the most recent start // element that has already been read. func (stream *Stream) Skip() error { - return stream.dec.Skip() + return nil +// return stream.dec.Skip() } // Decode the next stanza. Works like xml.Unmarshal but reads from the stream's @@ -138,7 +166,8 @@ func (stream *Stream) Decode(v interface{}) error { // xml.Decoder.DecodeElement. func (stream *Stream) DecodeElement(v interface{}, start *xml.StartElement) error { - // Explicity lookup next start element to ensure stream is validated. + // Explicity lookup next start element to ensure stream is validated, + // stanza is logged, etc. if start == nil { if se, err := stream.Next(nil); err != nil { return err @@ -147,5 +176,57 @@ func (stream *Stream) DecodeElement(v interface{}, start *xml.StartElement) erro } } - return stream.dec.DecodeElement(v, start) + return xml.Unmarshal([]byte(stream.stanzaBuf), v) +// return stream.dec.DecodeElement(v, start) +} + +// Collect the element with the start that's already been consumed into a +// buffer. Namespaces are munged so the buffer can be correctly parsed outside +// the context of the stream. This is used for logging the received data. +func collectElement(dec *xml.Decoder, start *xml.StartElement, nsmap nsMap) (string, error) { + + var collector struct { + XML []byte `xml:",innerxml"` + } + + if err := dec.DecodeElement(&collector, start); err != nil { + return "", err + } + + name := start.Name + attrs := start.Attr + + // Map the element's namespace. + if ns, ok := nsmap[name.Space]; ok { + // Element's namespace is one of the stream namespaces. Update the + // element's namespace and add the namespace to the element's attrs. + attrs = append(attrs, xml.Attr{xml.Name{"xmlns", ns}, name.Space}) + name = xml.Name{ns, name.Local} + } else { + // Check that Go's xml package hasn't duplicated the default ns as the + // element name's space. If so, clear it. + for _, attr := range attrs { + if attr.Name == (xml.Name{"", "xmlns"}) { + if name.Space == attr.Value { + name = xml.Name{"", start.Name.Local} + } + break + } + } + } + + start = &xml.StartElement{name, attrs} + + buf := new(bytes.Buffer) + if err := writeXMLStartElement(buf, start); err != nil { + return "", err + } + if _, err := buf.Write(collector.XML); err != nil { + return "", err + } + if err := writeXMLEndElement(buf, &xml.EndElement{start.Name}); err != nil { + return "", err + } + + return buf.String(), nil } diff --git a/src/xmpp/xml.go b/src/xmpp/xml.go index d8e22a6..fbbb9c8 100644 --- a/src/xmpp/xml.go +++ b/src/xmpp/xml.go @@ -6,6 +6,8 @@ import ( "io" ) +type nsMap map[string]string + // Write an xml.StartElement. func writeXMLStartElement(w io.Writer, start *xml.StartElement) error { if _, err := w.Write([]byte{'<'}); err != nil { @@ -28,6 +30,20 @@ func writeXMLStartElement(w io.Writer, start *xml.StartElement) error { return nil } +// Write an xml.StartElement. +func writeXMLEndElement(w io.Writer, end *xml.EndElement) error { + if _, err := w.Write([]byte{'<', '/'}); err != nil { + return err + } + if err := writeXMLName(w, end.Name); err != nil { + return err + } + if _, err := w.Write([]byte{'>'}); err != nil { + return err + } + return nil +} + // Write a xml.Name. func writeXMLName(w io.Writer, name xml.Name) error { if name.Space == "" {